From ca4d091a3ffe499f52fc45898683a249a4360888 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 23:50:55 +0300 Subject: [PATCH 01/18] Optimize Falcon block for inference --- src/petals/models/falcon/block.py | 375 +++++++++++++++++++++++++++++- tests/test_optimized_layers.py | 33 +++ 2 files changed, 405 insertions(+), 3 deletions(-) create mode 100644 tests/test_optimized_layers.py diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index e677e0689..c2e2cbdb4 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -3,12 +3,380 @@ Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py See commit history for authorship. """ +import math +from functools import partial from typing import Optional, Tuple import torch -from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.falcon.modeling_falcon import ( + FalconAttention, + FalconConfig, + FalconDecoderLayer, + FalconLinear, + FalconMLP, + FalconModel, + FalconRotaryEmbedding, + LayerNorm, + build_alibi_tensor, + dropout_add, + rotate_half, +) + KVCache = Tuple[torch.Tensor, torch.Tensor] +INFERENCE_MAX_LENGTH = 8192 + + +def apply_rotary(query, key, cos, sin): + return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + + +class OptimizedFalconRotaryEmbedding(FalconRotaryEmbedding): + def __init__(self, head_dim: int, base=10000): + super().__init__(head_dim, base) + self.cuda_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_apply_rotary(self, query, key, cos, sin): + if self.cuda_graph is None: + self.cuda_graph = torch.cuda.CUDAGraph() + self.input_surface = (query, key, cos, sin) + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + apply_rotary(*self.input_surface) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.cuda_graph): + self.static_outputs = apply_rotary(*self.input_surface) + + inputs = (query, key, cos, sin) + for static_input, data in zip(self.input_surface, inputs): + static_input.copy_(data) + self.cuda_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + if self.seq_len_cached == -1: + # warm up the cache + super().cos_sin(1, INFERENCE_MAX_LENGTH - 1, device=device, dtype=dtype) + return super().cos_sin( + seq_len=seq_len, past_key_values_length=past_key_values_length, device=device, dtype=dtype + ) + + def forward(self, query, key, past_key_values_length=0): + batch, seq_len, head_dim = query.shape + cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) + if seq_len == 1 and torch.is_inference_mode_enabled(): + return self._optimized_apply_rotary(query, key, cos, sin) + else: + return apply_rotary(query, key, cos, sin) + + +def split_heads( + fused_qkv: torch.Tensor, num_heads, num_kv_heads, head_dim +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch, seq_len, _ = fused_qkv.shape + qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) + query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3) + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value + + +class OptimizedFalconAttention(FalconAttention): + def __init__(self, config: FalconConfig): + nn.Module.__init__(self) + assert config.new_decoder_architecture + assert config.attention_dropout == 0.0 + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + self.hidden_dropout = config.hidden_dropout + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = self.inv_norm_factor + + qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim + + self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias) + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias) + self.num_kv_heads = config.num_kv_heads + + self._split_heads = partial( + split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim + ) + self.qkv_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_apply_qkv(self, hidden_states): + if self.qkv_graph is None: + self.qkv_graph = torch.cuda.CUDAGraph() + self.static_input = hidden_states + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + fused_qkv = self.query_key_value(hidden_states) + self._split_heads(fused_qkv) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.qkv_graph): + static_fused_qkv = self.query_key_value(hidden_states) + self.static_outputs = self._split_heads(static_fused_qkv) + + self.static_input.copy_(hidden_states) + self.qkv_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + assert alibi is None + assert not output_attentions + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): + query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states) + else: + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + num_kv_heads = self.num_heads + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) + key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + + attn_output = F.scaled_dot_product_attention( + query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=True + ) + + attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + + output_tensor = self.dense(attn_output) + + return output_tensor, present + + +class OptimizedFalconDecoderLayer(FalconDecoderLayer): + def __init__(self, config: FalconConfig): + nn.Module.__init__(self) + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.self_attention = OptimizedFalconAttention(config) + self.mlp = FalconMLP(config) + self.hidden_dropout = config.hidden_dropout + self.config = config + + assert config.new_decoder_architecture + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.ln_graph = None + self.static_input = None + + def _optimized_apply_ln(self, hidden_states): + if self.ln_graph is None: + self.ln_graph = torch.cuda.CUDAGraph() + self.static_input = hidden_states + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self.ln_attn(hidden_states) + self.ln_mlp(hidden_states) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.ln_graph): + ln_attn_output = self.ln_attn(hidden_states) + ln_mlp_output = self.ln_mlp(hidden_states) + self.static_outputs = (ln_attn_output, ln_mlp_output) + + self.static_input.copy_(hidden_states) + self.ln_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): + attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) + else: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + + attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + outputs = attn_outputs[1:] + + mlp_output = self.mlp(mlp_layernorm_out) + mlp_output += attention_output + + output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class _WrappedFalconBlock(OptimizedFalconDecoderLayer): + def __init__(self, config: FalconConfig): + super().__init__(config) + assert not self.config.alibi + + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + **kwargs, + ): + assert attention_mask is None + + if layer_past is not None: + layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=None, + alibi=None, + layer_past=layer_past, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + key_states = key_states.permute(0, 2, 1) + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + + if self.config.new_decoder_architecture: + key_states = self._expand_states(key_states) + value_states = self._expand_states(value_states) + + return (key_states, value_states) + + def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + if self.config.new_decoder_architecture: + key_states = self._collapse_states(key_states) + value_states = self._collapse_states(value_states) + + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + key_states = key_states.permute(0, 2, 1) + + return (key_states, value_states) + + def _expand_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_kv_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads + + state = state.view(batch_size, 1, self.config.num_kv_heads, seq_len, head_dim) + # Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1 + state = state.expand(-1, self.config.num_key_value_groups, -1, -1, -1) + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) + return state + + def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_attn_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads + + state = state.view(batch_size, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim) + state = state[:, 0] + state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) + return state class WrappedFalconBlock(FalconDecoderLayer): @@ -19,8 +387,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, alibi: Optional[torch.Tensor] = None, layer_past: Optional[KVCache] = None, + layer_past: Optional[KVCache] = None, use_cache: bool = False, - **kwargs + **kwargs, ): batch_size, seq_length = hidden_states.shape[:2] @@ -41,7 +410,7 @@ def forward( alibi=alibi, layer_past=layer_past, use_cache=use_cache, - **kwargs + **kwargs, ) if use_cache: diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py new file mode 100644 index 000000000..5b99b5761 --- /dev/null +++ b/tests/test_optimized_layers.py @@ -0,0 +1,33 @@ +from petals.models.falcon.block import WrappedFalconBlock +from petals.server.from_pretrained import load_pretrained_block +from petals.utils.auto_config import AutoDistributedConfig +from petals.server.block_utils import resolve_block_dtype +from petals.utils.convert_block import QuantType, convert_block +import torch + + +def test_falcon(): + config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-rw-1b") + config.alibi = False + config.new_decoder_architecture = True + + device = "cuda:0" + tensor_parallel_devices = (device,) + dtype = torch.bfloat16 + quant_type = QuantType.NONE + + block = config.block_class(config).to(dtype) + block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + + unopt_block = WrappedFalconBlock(config).to(dtype) + unopt_block = convert_block( + unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True + ) + + unopt_block.load_state_dict(block.state_dict()) + + for _ in range(3): + dummy_input = torch.randn(1, 1, config.hidden_size, device="cuda", dtype=dtype) + block_output = block(dummy_input) + unopt_block_output = unopt_block(dummy_input) + assert torch.allclose(block_output[0], unopt_block_output[0], atol=1e-6, rtol=0) From 1f006c59a1fda6d08ac0bfeaa1ba37105f1c50c8 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 23:58:13 +0300 Subject: [PATCH 02/18] Fix class names --- src/petals/models/falcon/block.py | 4 ++-- tests/test_optimized_layers.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index c2e2cbdb4..a68c0b940 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -298,7 +298,7 @@ def forward( return outputs # hidden_states, present, attentions -class _WrappedFalconBlock(OptimizedFalconDecoderLayer): +class WrappedFalconBlock(OptimizedFalconDecoderLayer): def __init__(self, config: FalconConfig): super().__init__(config) assert not self.config.alibi @@ -379,7 +379,7 @@ def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: return state -class WrappedFalconBlock(FalconDecoderLayer): +class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): def forward( self, hidden_states: torch.Tensor, diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 5b99b5761..2be80f491 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -1,4 +1,4 @@ -from petals.models.falcon.block import WrappedFalconBlock +from petals.models.falcon.block import UnoptimizedWrappedFalconBlock from petals.server.from_pretrained import load_pretrained_block from petals.utils.auto_config import AutoDistributedConfig from petals.server.block_utils import resolve_block_dtype @@ -11,7 +11,7 @@ def test_falcon(): config.alibi = False config.new_decoder_architecture = True - device = "cuda:0" + device = "cpu" tensor_parallel_devices = (device,) dtype = torch.bfloat16 quant_type = QuantType.NONE @@ -19,7 +19,7 @@ def test_falcon(): block = config.block_class(config).to(dtype) block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) - unopt_block = WrappedFalconBlock(config).to(dtype) + unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) unopt_block = convert_block( unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True ) @@ -27,7 +27,7 @@ def test_falcon(): unopt_block.load_state_dict(block.state_dict()) for _ in range(3): - dummy_input = torch.randn(1, 1, config.hidden_size, device="cuda", dtype=dtype) + dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype) block_output = block(dummy_input) unopt_block_output = unopt_block(dummy_input) assert torch.allclose(block_output[0], unopt_block_output[0], atol=1e-6, rtol=0) From 1fc22bd69f5e312af83fe3ad3ed4fa2859a5723c Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 00:09:22 +0300 Subject: [PATCH 03/18] Post-rebase changes --- src/petals/models/falcon/block.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index a68c0b940..5d1cfa379 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -363,18 +363,17 @@ def _expand_states(self, state: torch.Tensor) -> torch.Tensor: batch_size_x_num_kv_heads, seq_len, head_dim = state.shape batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads - state = state.view(batch_size, 1, self.config.num_kv_heads, seq_len, head_dim) - # Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1 - state = state.expand(-1, self.config.num_key_value_groups, -1, -1, -1) - state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) + state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) + state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy return state def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: batch_size_x_num_attn_heads, seq_len, head_dim = state.shape batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads - state = state.view(batch_size, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim) - state = state[:, 0] + state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) + state = state[:, :, 0] state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) return state @@ -387,7 +386,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, alibi: Optional[torch.Tensor] = None, layer_past: Optional[KVCache] = None, - layer_past: Optional[KVCache] = None, use_cache: bool = False, **kwargs, ): From 67764fea9e595e711d949105bcafade86553f0c4 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 00:16:10 +0300 Subject: [PATCH 04/18] Fix formatting, reduce diff --- src/petals/models/falcon/block.py | 6 +----- tests/test_optimized_layers.py | 5 +++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 5d1cfa379..5b045a7df 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -24,7 +24,6 @@ rotate_half, ) - KVCache = Tuple[torch.Tensor, torch.Tensor] INFERENCE_MAX_LENGTH = 8192 @@ -225,6 +224,7 @@ def __init__(self, config: FalconConfig): self.hidden_dropout = config.hidden_dropout self.config = config + assert not self.config.alibi assert config.new_decoder_architecture self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) @@ -299,10 +299,6 @@ def forward( class WrappedFalconBlock(OptimizedFalconDecoderLayer): - def __init__(self, config: FalconConfig): - super().__init__(config) - assert not self.config.alibi - def forward( self, hidden_states: torch.Tensor, diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 2be80f491..1a88f7dbf 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -1,9 +1,10 @@ +import torch + from petals.models.falcon.block import UnoptimizedWrappedFalconBlock +from petals.server.block_utils import resolve_block_dtype from petals.server.from_pretrained import load_pretrained_block from petals.utils.auto_config import AutoDistributedConfig -from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_block import QuantType, convert_block -import torch def test_falcon(): From 111bf7e1251cb3eba06b8a714d9312cea70ff4d6 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 00:47:11 +0300 Subject: [PATCH 05/18] Fix the test --- tests/test_optimized_layers.py | 93 ++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 1a88f7dbf..cb475e2db 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -1,12 +1,99 @@ +from typing import Optional, Tuple + +import pytest import torch +from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor -from petals.models.falcon.block import UnoptimizedWrappedFalconBlock -from petals.server.block_utils import resolve_block_dtype -from petals.server.from_pretrained import load_pretrained_block from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[KVCache] = None, + use_cache: bool = False, + **kwargs, + ): + batch_size, seq_length = hidden_states.shape[:2] + + if layer_past is not None: + layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + past_length = 0 if layer_past is None else layer_past[0].shape[1] + seq_length_with_past = seq_length + past_length + + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None and self.config.alibi: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + alibi=alibi, + layer_past=layer_past, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + key_states = key_states.permute(0, 2, 1) + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + + if self.config.new_decoder_architecture: + key_states = self._expand_states(key_states) + value_states = self._expand_states(value_states) + + return (key_states, value_states) + + def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + if self.config.new_decoder_architecture: + key_states = self._collapse_states(key_states) + value_states = self._collapse_states(value_states) + + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + key_states = key_states.permute(0, 2, 1) + + return (key_states, value_states) + + def _expand_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_kv_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads + + state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) + state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy + return state + + def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_attn_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads + + state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) + state = state[:, :, 0] + state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) + return state + +@pytest.mark.forked def test_falcon(): config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-rw-1b") config.alibi = False From ce401f116351858796878ee27a74467375010657 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 00:47:27 +0300 Subject: [PATCH 06/18] Make cos_cached/sin_cached buffers --- src/petals/models/falcon/block.py | 106 +++++------------------------- 1 file changed, 18 insertions(+), 88 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 5b045a7df..c857f6041 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -16,10 +16,8 @@ FalconDecoderLayer, FalconLinear, FalconMLP, - FalconModel, FalconRotaryEmbedding, LayerNorm, - build_alibi_tensor, dropout_add, rotate_half, ) @@ -61,11 +59,26 @@ def _optimized_apply_rotary(self, query, key, cos, sin): return tuple(o.detach() for o in self.static_outputs) def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + total_length = seq_len + past_key_values_length if self.seq_len_cached == -1: # warm up the cache - super().cos_sin(1, INFERENCE_MAX_LENGTH - 1, device=device, dtype=dtype) - return super().cos_sin( - seq_len=seq_len, past_key_values_length=past_key_values_length, device=device, dtype=dtype + total_length = max(INFERENCE_MAX_LENGTH, total_length) + + if total_length > self.seq_len_cached: + self.seq_len_cached = total_length + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype)) + self.register_buffer("sin_cached", emb.cos()[None, :, :].type(dtype)) + + return ( + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], ) def forward(self, query, key, past_key_values_length=0): @@ -372,86 +385,3 @@ def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: state = state[:, :, 0] state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) return state - - -class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): - def forward( - self, - hidden_states: torch.Tensor, - *args, - attention_mask: Optional[torch.Tensor] = None, - alibi: Optional[torch.Tensor] = None, - layer_past: Optional[KVCache] = None, - use_cache: bool = False, - **kwargs, - ): - batch_size, seq_length = hidden_states.shape[:2] - - if layer_past is not None: - layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) - past_length = 0 if layer_past is None else layer_past[0].shape[1] - seq_length_with_past = seq_length + past_length - - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - if alibi is None and self.config.alibi: - alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) - attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) - - outputs = super().forward( - hidden_states, - *args, - attention_mask=attention_mask, - alibi=alibi, - layer_past=layer_past, - use_cache=use_cache, - **kwargs, - ) - - if use_cache: - present_key_value = outputs[-1] - present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) - outputs = outputs[:-1] + (present_key_value,) - - return outputs - - def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: - key_states, value_states = key_value - - key_states = key_states.permute(0, 2, 1) - assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] - - if self.config.new_decoder_architecture: - key_states = self._expand_states(key_states) - value_states = self._expand_states(value_states) - - return (key_states, value_states) - - def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: - key_states, value_states = key_value - - if self.config.new_decoder_architecture: - key_states = self._collapse_states(key_states) - value_states = self._collapse_states(value_states) - - assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] - key_states = key_states.permute(0, 2, 1) - - return (key_states, value_states) - - def _expand_states(self, state: torch.Tensor) -> torch.Tensor: - batch_size_x_num_kv_heads, seq_len, head_dim = state.shape - batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads - - state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) - state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy - state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy - return state - - def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: - batch_size_x_num_attn_heads, seq_len, head_dim = state.shape - batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads - - state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) - state = state[:, :, 0] - state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) - return state From 2c1452de5c91ff3c143c93277b1976c870f2b6e6 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 01:05:27 +0300 Subject: [PATCH 07/18] Fix buffer registration --- src/petals/models/falcon/block.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index c857f6041..d1caac022 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -16,7 +16,6 @@ FalconDecoderLayer, FalconLinear, FalconMLP, - FalconRotaryEmbedding, LayerNorm, dropout_add, rotate_half, @@ -30,9 +29,14 @@ def apply_rotary(query, key, cos, sin): return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) -class OptimizedFalconRotaryEmbedding(FalconRotaryEmbedding): +class OptimizedFalconRotaryEmbedding(nn.Module): def __init__(self, head_dim: int, base=10000): - super().__init__(head_dim, base) + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = -1 + self.cuda_graph = None self.input_surface = None self.static_outputs = None From ae30427276b28a88cada7cccee2ca5de355fda46 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 01:20:33 +0300 Subject: [PATCH 08/18] Make the block compatible with other architectures --- src/petals/models/falcon/block.py | 80 +++++++++++++++++++++---------- tests/test_optimized_layers.py | 6 +-- 2 files changed, 58 insertions(+), 28 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index d1caac022..bac293948 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -110,8 +110,6 @@ def split_heads( class OptimizedFalconAttention(FalconAttention): def __init__(self, config: FalconConfig): nn.Module.__init__(self) - assert config.new_decoder_architecture - assert config.attention_dropout == 0.0 self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -130,21 +128,26 @@ def __init__(self, config: FalconConfig): # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) self.beta = self.inv_norm_factor - - qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim - + if config.new_decoder_architecture: + qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size + 2 * self.head_dim + else: + qkv_out_dim = 3 * self.hidden_size self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias) self.new_decoder_architecture = config.new_decoder_architecture self.multi_query = config.multi_query self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias) - self.num_kv_heads = config.num_kv_heads + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 - self._split_heads = partial( - split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim - ) - self.qkv_graph = None - self.input_surface = None - self.static_outputs = None + if self.new_decoder_architecture: + self._split_heads = partial( + split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim + ) + self.qkv_graph = None + self.input_surface = None + self.static_outputs = None def _optimized_apply_qkv(self, hidden_states): if self.qkv_graph is None: @@ -180,7 +183,7 @@ def forward( assert alibi is None assert not output_attentions - if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): + if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states) else: fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] @@ -236,18 +239,30 @@ def __init__(self, config: FalconConfig): nn.Module.__init__(self) hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = OptimizedFalconAttention(config) + self.mlp = FalconMLP(config) self.hidden_dropout = config.hidden_dropout self.config = config - assert not self.config.alibi - assert config.new_decoder_architecture - self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.self_attention = OptimizedFalconAttention(config) + + if self.config.alibi or not config.new_decoder_architecture: + if config.new_decoder_architecture: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + else: + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.ln_graph = None - self.static_input = None + self.ln_graph = None + self.static_input = None def _optimized_apply_ln(self, hidden_states): if self.ln_graph is None: @@ -283,11 +298,14 @@ def forward( ): residual = hidden_states - if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): - attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) + if self.config.new_decoder_architecture: + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): + attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) + else: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) else: - attention_layernorm_out = self.ln_attn(hidden_states) - mlp_layernorm_out = self.ln_mlp(hidden_states) + attention_layernorm_out = self.input_layernorm(hidden_states) attn_outputs = self.self_attention( attention_layernorm_out, @@ -300,10 +318,22 @@ def forward( ) attention_output = attn_outputs[0] + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) + outputs = attn_outputs[1:] mlp_output = self.mlp(mlp_layernorm_out) - mlp_output += attention_output + + if self.config.new_decoder_architecture or self.config.parallel_attn: + mlp_output += attention_output output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index cb475e2db..93f544c4d 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -6,6 +6,7 @@ from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block +from test_utils import MODEL_NAME KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -93,11 +94,10 @@ def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: return state +@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models") @pytest.mark.forked def test_falcon(): - config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-rw-1b") - config.alibi = False - config.new_decoder_architecture = True + config = AutoDistributedConfig.from_pretrained(MODEL_NAME) device = "cpu" tensor_parallel_devices = (device,) From 841a0d52629b79c65afd88eb06c0b2c69129057e Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 01:56:29 +0300 Subject: [PATCH 09/18] Improve test and compatibility --- src/petals/models/falcon/block.py | 77 ++++++++++++++++++++++++++----- tests/test_optimized_layers.py | 11 +++-- 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index bac293948..2613b052f 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -16,7 +16,9 @@ FalconDecoderLayer, FalconLinear, FalconMLP, + FalconModel, LayerNorm, + build_alibi_tensor, dropout_add, rotate_half, ) @@ -180,7 +182,6 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ): - assert alibi is None assert not output_attentions if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): @@ -212,6 +213,7 @@ def forward( key_layer = torch.cat((past_key, key_layer), dim=1) value_layer = torch.cat((past_value, value_layer), dim=1) + _, kv_length, _ = key_layer.shape if use_cache: present = (key_layer, value_layer) else: @@ -221,17 +223,59 @@ def forward( key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) - attn_output = F.scaled_dot_product_attention( - query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=True - ) + attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) + + if alibi is None: + attn_output = F.scaled_dot_product_attention( + query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False + ) - attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - output_tensor = self.dense(attn_output) + output_tensor = self.dense(attn_output) - return output_tensor, present + return output_tensor, present + else: + matmul_result = query_layer_ @ key_layer_.transpose(-1, -2) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) + # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by + # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically + # equivalent and more performant, but there might be a numerical difference. If you're reading this + # and you'd like to experiment and maybe file a PR, feel free! + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1) + + # change view [batch_size, q_length, num_heads * head_dim] + context_layer = self._merge_heads(context_layer) + + output_tensor = self.dense(context_layer) + + if output_attentions: + return output_tensor, present, attention_probs + else: + return output_tensor, present class OptimizedFalconDecoderLayer(FalconDecoderLayer): @@ -352,20 +396,29 @@ def forward( *args, attention_mask: Optional[torch.Tensor] = None, alibi: Optional[torch.Tensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[KVCache] = None, use_cache: bool = False, **kwargs, ): assert attention_mask is None + batch_size, seq_length = hidden_states.shape[:2] + if layer_past is not None: layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + past_length = 0 if layer_past is None else layer_past[0].shape[1] + seq_length_with_past = seq_length + past_length + + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None and self.config.alibi: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) outputs = super().forward( hidden_states, *args, - attention_mask=None, - alibi=None, + attention_mask=attention_mask, + alibi=alibi, layer_past=layer_past, use_cache=use_cache, **kwargs, diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 93f544c4d..36b484b70 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -113,9 +113,12 @@ def test_falcon(): ) unopt_block.load_state_dict(block.state_dict()) + cache = unopt_cache = None - for _ in range(3): + for l in range(3): dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype) - block_output = block(dummy_input) - unopt_block_output = unopt_block(dummy_input) - assert torch.allclose(block_output[0], unopt_block_output[0], atol=1e-6, rtol=0) + block_output, cache = block(dummy_input, layer_past=cache, use_cache=True) + unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True) + assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l + assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l + assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l From d56f57acd2255b41a9d02eef4675c0122a85ab40 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 02:47:58 +0300 Subject: [PATCH 10/18] Fix rotary embeddings --- src/petals/models/falcon/block.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 2613b052f..1f1a20a54 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -43,6 +43,13 @@ def __init__(self, head_dim: int, base=10000): self.input_surface = None self.static_outputs = None + self.cos_sin( + seq_len=INFERENCE_MAX_LENGTH, + past_key_values_length=0, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + def _optimized_apply_rotary(self, query, key, cos, sin): if self.cuda_graph is None: self.cuda_graph = torch.cuda.CUDAGraph() @@ -80,11 +87,11 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype emb = emb.float() self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype)) - self.register_buffer("sin_cached", emb.cos()[None, :, :].type(dtype)) + self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype)) return ( - self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], - self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), ) def forward(self, query, key, past_key_values_length=0): From cfaf6c1975973177dfc9a6bc8f3a57da2eea05c9 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 02:51:49 +0300 Subject: [PATCH 11/18] Fix rotary embeddings --- src/petals/models/falcon/block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 1f1a20a54..2571a270e 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -86,8 +86,8 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype if dtype in [torch.float16, torch.bfloat16]: emb = emb.float() - self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype)) - self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype)) + self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False) return ( self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), From 1f2ef79da3c8cae7f7c8c479ea22af87760a6743 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 04:30:31 +0300 Subject: [PATCH 12/18] WIP disable graphs --- src/petals/models/falcon/block.py | 42 ++++++++++++++++++------------- tests/test_optimized_layers.py | 13 +++++----- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 2571a270e..12342d93c 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -20,13 +20,19 @@ LayerNorm, build_alibi_tensor, dropout_add, - rotate_half, + FalconRotaryEmbedding, ) KVCache = Tuple[torch.Tensor, torch.Tensor] INFERENCE_MAX_LENGTH = 8192 +# @torch.jit.script +def rotate_half(x): + x1, x2 = torch.chunk(x, 2, dim=2) + return torch.cat((-x2, x1), dim=-1) + +# @torch.jit.script def apply_rotary(query, key, cos, sin): return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) @@ -97,14 +103,15 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype def forward(self, query, key, past_key_values_length=0): batch, seq_len, head_dim = query.shape cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) - if seq_len == 1 and torch.is_inference_mode_enabled(): - return self._optimized_apply_rotary(query, key, cos, sin) - else: - return apply_rotary(query, key, cos, sin) - + # print(cos, sin) + # if seq_len == 1 and torch.is_inference_mode_enabled(): + # return self._optimized_apply_rotary(query, key, cos, sin) + # else: + return apply_rotary(query, key, cos, sin) +# @torch.jit.script def split_heads( - fused_qkv: torch.Tensor, num_heads, num_kv_heads, head_dim + fused_qkv: torch.Tensor, num_heads:int, num_kv_heads:int, head_dim:int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch, seq_len, _ = fused_qkv.shape qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) @@ -161,21 +168,21 @@ def __init__(self, config: FalconConfig): def _optimized_apply_qkv(self, hidden_states): if self.qkv_graph is None: self.qkv_graph = torch.cuda.CUDAGraph() - self.static_input = hidden_states + self.input_surface = torch.randn_like(hidden_states) s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): - fused_qkv = self.query_key_value(hidden_states) + fused_qkv = self.query_key_value(self.input_surface) self._split_heads(fused_qkv) torch.cuda.current_stream().wait_stream(s) with torch.cuda.graph(self.qkv_graph): - static_fused_qkv = self.query_key_value(hidden_states) + static_fused_qkv = self.query_key_value(self.input_surface) self.static_outputs = self._split_heads(static_fused_qkv) - self.static_input.copy_(hidden_states) + self.input_surface.copy_(hidden_states) self.qkv_graph.replay() return tuple(o.detach() for o in self.static_outputs) @@ -191,12 +198,12 @@ def forward( ): assert not output_attentions - if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): - query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states) - else: - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + # if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): + # query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states) + # else: + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) num_kv_heads = self.num_heads batch_size, query_length, _, _ = query_layer.shape @@ -314,6 +321,7 @@ def __init__(self, config: FalconConfig): self.ln_graph = None self.static_input = None + self.static_outputs = None def _optimized_apply_ln(self, hidden_states): if self.ln_graph is None: diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 36b484b70..08aa24348 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -116,9 +116,10 @@ def test_falcon(): cache = unopt_cache = None for l in range(3): - dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype) - block_output, cache = block(dummy_input, layer_past=cache, use_cache=True) - unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True) - assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l - assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l - assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l + with torch.inference_mode(): + dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype) + block_output, cache = block(dummy_input, layer_past=cache, use_cache=True) + unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True) + assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l + assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l + assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l From b941df5d2ffee682d3bd123119602fe42f8d6cea Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 04:38:00 +0300 Subject: [PATCH 13/18] Fix formatting --- src/petals/models/falcon/block.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 12342d93c..60d6e5dcd 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -17,10 +17,10 @@ FalconLinear, FalconMLP, FalconModel, + FalconRotaryEmbedding, LayerNorm, build_alibi_tensor, dropout_add, - FalconRotaryEmbedding, ) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -109,9 +109,10 @@ def forward(self, query, key, past_key_values_length=0): # else: return apply_rotary(query, key, cos, sin) + # @torch.jit.script def split_heads( - fused_qkv: torch.Tensor, num_heads:int, num_kv_heads:int, head_dim:int + fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch, seq_len, _ = fused_qkv.shape qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) From 177669e97f3777275195fd3cb4cf94862242f470 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 11:25:49 +0300 Subject: [PATCH 14/18] Rollback CUDA graphs --- src/petals/models/falcon/block.py | 147 +++--------------------------- 1 file changed, 13 insertions(+), 134 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 60d6e5dcd..31eda92b9 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -4,7 +4,6 @@ See commit history for authorship. """ import math -from functools import partial from typing import Optional, Tuple import torch @@ -17,25 +16,15 @@ FalconLinear, FalconMLP, FalconModel, - FalconRotaryEmbedding, LayerNorm, build_alibi_tensor, dropout_add, + rotate_half, ) KVCache = Tuple[torch.Tensor, torch.Tensor] INFERENCE_MAX_LENGTH = 8192 -# @torch.jit.script -def rotate_half(x): - x1, x2 = torch.chunk(x, 2, dim=2) - return torch.cat((-x2, x1), dim=-1) - - -# @torch.jit.script -def apply_rotary(query, key, cos, sin): - return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) - class OptimizedFalconRotaryEmbedding(nn.Module): def __init__(self, head_dim: int, base=10000): @@ -45,38 +34,6 @@ def __init__(self, head_dim: int, base=10000): self.head_dim = head_dim self.seq_len_cached = -1 - self.cuda_graph = None - self.input_surface = None - self.static_outputs = None - - self.cos_sin( - seq_len=INFERENCE_MAX_LENGTH, - past_key_values_length=0, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - - def _optimized_apply_rotary(self, query, key, cos, sin): - if self.cuda_graph is None: - self.cuda_graph = torch.cuda.CUDAGraph() - self.input_surface = (query, key, cos, sin) - - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(3): - apply_rotary(*self.input_surface) - torch.cuda.current_stream().wait_stream(s) - - with torch.cuda.graph(self.cuda_graph): - self.static_outputs = apply_rotary(*self.input_surface) - - inputs = (query, key, cos, sin) - for static_input, data in zip(self.input_surface, inputs): - static_input.copy_(data) - self.cuda_graph.replay() - return tuple(o.detach() for o in self.static_outputs) - def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: total_length = seq_len + past_key_values_length if self.seq_len_cached == -1: @@ -84,16 +41,17 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype total_length = max(INFERENCE_MAX_LENGTH, total_length) if total_length > self.seq_len_cached: - self.seq_len_cached = total_length - t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(device) + with torch.inference_mode(False): + self.seq_len_cached = total_length + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) - if dtype in [torch.float16, torch.bfloat16]: - emb = emb.float() + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() - self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False) return ( self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), @@ -103,25 +61,7 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype def forward(self, query, key, past_key_values_length=0): batch, seq_len, head_dim = query.shape cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) - # print(cos, sin) - # if seq_len == 1 and torch.is_inference_mode_enabled(): - # return self._optimized_apply_rotary(query, key, cos, sin) - # else: - return apply_rotary(query, key, cos, sin) - - -# @torch.jit.script -def split_heads( - fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch, seq_len, _ = fused_qkv.shape - qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) - query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3) - key = torch.broadcast_to(key, query.shape) - value = torch.broadcast_to(value, query.shape) - - query, key, value = [x.flatten(2, 3) for x in (query, key, value)] - return query, key, value + return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) class OptimizedFalconAttention(FalconAttention): @@ -158,35 +98,6 @@ def __init__(self, config: FalconConfig): self.attention_dropout = nn.Dropout(config.attention_dropout) self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 - if self.new_decoder_architecture: - self._split_heads = partial( - split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim - ) - self.qkv_graph = None - self.input_surface = None - self.static_outputs = None - - def _optimized_apply_qkv(self, hidden_states): - if self.qkv_graph is None: - self.qkv_graph = torch.cuda.CUDAGraph() - self.input_surface = torch.randn_like(hidden_states) - - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(3): - fused_qkv = self.query_key_value(self.input_surface) - self._split_heads(fused_qkv) - torch.cuda.current_stream().wait_stream(s) - - with torch.cuda.graph(self.qkv_graph): - static_fused_qkv = self.query_key_value(self.input_surface) - self.static_outputs = self._split_heads(static_fused_qkv) - - self.input_surface.copy_(hidden_states) - self.qkv_graph.replay() - return tuple(o.detach() for o in self.static_outputs) - def forward( self, hidden_states: torch.Tensor, @@ -199,9 +110,6 @@ def forward( ): assert not output_attentions - # if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): - # query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states) - # else: fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) @@ -320,32 +228,6 @@ def __init__(self, config: FalconConfig): self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.ln_graph = None - self.static_input = None - self.static_outputs = None - - def _optimized_apply_ln(self, hidden_states): - if self.ln_graph is None: - self.ln_graph = torch.cuda.CUDAGraph() - self.static_input = hidden_states - - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(3): - self.ln_attn(hidden_states) - self.ln_mlp(hidden_states) - torch.cuda.current_stream().wait_stream(s) - - with torch.cuda.graph(self.ln_graph): - ln_attn_output = self.ln_attn(hidden_states) - ln_mlp_output = self.ln_mlp(hidden_states) - self.static_outputs = (ln_attn_output, ln_mlp_output) - - self.static_input.copy_(hidden_states) - self.ln_graph.replay() - return tuple(o.detach() for o in self.static_outputs) - def forward( self, hidden_states: torch.Tensor, @@ -359,11 +241,8 @@ def forward( residual = hidden_states if self.config.new_decoder_architecture: - if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled(): - attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) - else: - attention_layernorm_out = self.ln_attn(hidden_states) - mlp_layernorm_out = self.ln_mlp(hidden_states) + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) else: attention_layernorm_out = self.input_layernorm(hidden_states) From 91f6248535d6d248fbea6891e657908358bd4bf3 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 12:00:59 +0300 Subject: [PATCH 15/18] Run tests on CUDA and CPU, --- tests/test_optimized_layers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 08aa24348..20aae6fee 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -95,11 +95,14 @@ def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: @pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models") +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) @pytest.mark.forked -def test_falcon(): +def test_falcon(device): + if device == "cuda:0" and not torch.cuda.is_available(): + pytest.skip("CUDA tests can be run only in CUDA-enabled setups") + config = AutoDistributedConfig.from_pretrained(MODEL_NAME) - device = "cpu" tensor_parallel_devices = (device,) dtype = torch.bfloat16 quant_type = QuantType.NONE From ea6c037c8bac7ef946f968a5a8b26e9437af4dba Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 12:01:10 +0300 Subject: [PATCH 16/18] Enable CUDA graphs only on CUDA --- src/petals/models/falcon/block.py | 124 ++++++++++++++++++++++++++++-- 1 file changed, 118 insertions(+), 6 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 31eda92b9..5928d23c4 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -4,6 +4,7 @@ See commit history for authorship. """ import math +from functools import partial from typing import Optional, Tuple import torch @@ -26,6 +27,10 @@ INFERENCE_MAX_LENGTH = 8192 +def apply_rotary(query, key, cos, sin): + return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + + class OptimizedFalconRotaryEmbedding(nn.Module): def __init__(self, head_dim: int, base=10000): super().__init__() @@ -34,6 +39,31 @@ def __init__(self, head_dim: int, base=10000): self.head_dim = head_dim self.seq_len_cached = -1 + self.cuda_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_apply_rotary(self, query, key, cos, sin): + if self.cuda_graph is None: + self.cuda_graph = torch.cuda.CUDAGraph() + self.input_surface = (query, key, cos, sin) + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + apply_rotary(*self.input_surface) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.cuda_graph): + self.static_outputs = apply_rotary(*self.input_surface) + + inputs = (query, key, cos, sin) + for static_input, data in zip(self.input_surface, inputs): + static_input.copy_(data) + self.cuda_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: total_length = seq_len + past_key_values_length if self.seq_len_cached == -1: @@ -61,7 +91,23 @@ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype def forward(self, query, key, past_key_values_length=0): batch, seq_len, head_dim = query.shape cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) - return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda": + return self._optimized_apply_rotary(query, key, cos, sin) + else: + return apply_rotary(query, key, cos, sin) + + +def split_heads( + fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch, seq_len, _ = fused_qkv.shape + qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) + query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3) + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value class OptimizedFalconAttention(FalconAttention): @@ -98,6 +144,35 @@ def __init__(self, config: FalconConfig): self.attention_dropout = nn.Dropout(config.attention_dropout) self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 + if self.new_decoder_architecture: + self._split_heads = partial( + split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim + ) + self.qkv_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_apply_qkv(self, hidden_states): + if self.qkv_graph is None: + self.qkv_graph = torch.cuda.CUDAGraph() + self.input_surface = torch.randn_like(hidden_states) + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + fused_qkv = self.query_key_value(self.input_surface) + self._split_heads(fused_qkv) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.qkv_graph): + static_fused_qkv = self.query_key_value(self.input_surface) + self.static_outputs = self._split_heads(static_fused_qkv) + + self.input_surface.copy_(hidden_states) + self.qkv_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + def forward( self, hidden_states: torch.Tensor, @@ -110,9 +185,17 @@ def forward( ): assert not output_attentions - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + if ( + self.new_decoder_architecture + and hidden_states.size(1) == 1 + and torch.is_inference_mode_enabled() + and hidden_states.device.type == "cuda" + ): + query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states) + else: + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) num_kv_heads = self.num_heads batch_size, query_length, _, _ = query_layer.shape @@ -228,6 +311,32 @@ def __init__(self, config: FalconConfig): self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_graph = None + self.static_input = None + self.static_outputs = None + + def _optimized_apply_ln(self, hidden_states): + if self.ln_graph is None: + self.ln_graph = torch.cuda.CUDAGraph() + self.static_input = hidden_states + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self.ln_attn(hidden_states) + self.ln_mlp(hidden_states) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.ln_graph): + ln_attn_output = self.ln_attn(hidden_states) + ln_mlp_output = self.ln_mlp(hidden_states) + self.static_outputs = (ln_attn_output, ln_mlp_output) + + self.static_input.copy_(hidden_states) + self.ln_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + def forward( self, hidden_states: torch.Tensor, @@ -241,8 +350,11 @@ def forward( residual = hidden_states if self.config.new_decoder_architecture: - attention_layernorm_out = self.ln_attn(hidden_states) - mlp_layernorm_out = self.ln_mlp(hidden_states) + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) + else: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) else: attention_layernorm_out = self.input_layernorm(hidden_states) From 2c27c19df44f7c30871b043ef8cf4ecbecfb215a Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 14:42:13 +0300 Subject: [PATCH 17/18] Do not fuse split_heads with qkv This is most likely due to bitsandbytes performing work not captured in the graph --- src/petals/models/falcon/block.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index 5928d23c4..a510abaa1 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -148,29 +148,27 @@ def __init__(self, config: FalconConfig): self._split_heads = partial( split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim ) - self.qkv_graph = None + self.split_graph = None self.input_surface = None self.static_outputs = None - def _optimized_apply_qkv(self, hidden_states): - if self.qkv_graph is None: - self.qkv_graph = torch.cuda.CUDAGraph() - self.input_surface = torch.randn_like(hidden_states) + def _optimized_split_heads(self, fused_qkv): + if self.split_graph is None: + self.split_graph = torch.cuda.CUDAGraph() + self.input_surface = fused_qkv s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): - fused_qkv = self.query_key_value(self.input_surface) self._split_heads(fused_qkv) torch.cuda.current_stream().wait_stream(s) - with torch.cuda.graph(self.qkv_graph): - static_fused_qkv = self.query_key_value(self.input_surface) - self.static_outputs = self._split_heads(static_fused_qkv) + with torch.cuda.graph(self.split_graph): + self.static_outputs = self._split_heads(self.input_surface) - self.input_surface.copy_(hidden_states) - self.qkv_graph.replay() + self.input_surface.copy_(fused_qkv) + self.split_graph.replay() return tuple(o.detach() for o in self.static_outputs) def forward( @@ -185,15 +183,16 @@ def forward( ): assert not output_attentions + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + if ( self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda" ): - query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states) + query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv) else: - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) From 52baffb0569a42fb141c7a6e31e29e69165e7674 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 14:43:37 +0300 Subject: [PATCH 18/18] Update test_optimized_layers --- tests/test_optimized_layers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 20aae6fee..5baa1a2cf 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -118,11 +118,11 @@ def test_falcon(device): unopt_block.load_state_dict(block.state_dict()) cache = unopt_cache = None - for l in range(3): - with torch.inference_mode(): - dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype) + with torch.inference_mode(): + for length in [10, 1, 1, 1]: + dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype) block_output, cache = block(dummy_input, layer_past=cache, use_cache=True) unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True) - assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l - assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l - assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l + assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length + assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length + assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length