Bump transformers and accelerate versions #554

Merged
19 commits merged on Feb 15, 2024
fix optimized layers test
dvmazur committed Feb 2, 2024
commit 61456d99689349feed54b8b536c7cd5215bbe13f
tests/test_optimized_layers.py: 22 changes (14 additions, 8 deletions)
@@ -4,9 +4,11 @@
 import torch
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
+from transformers.cache_utils import DynamicCache
 from test_utils import MODEL_NAME
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -116,6 +118,8 @@ def forward(
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+        elif use_cache:
+            past_key_value = DynamicCache()
 
         if position_ids is None:
             device = hidden_states.device
@@ -131,8 +135,9 @@
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(
@@ -156,19 +161,20 @@ def forward(
 
     def _reorder_cache_from_bloom_to_llama(
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
-    ) -> Tuple[torch.Tensor]:
+    ) -> DynamicCache:
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
         key_states = key_states.view(
             batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
         value_states = value_states.view(*key_states.shape)
-        return (key_states, value_states)
+        past_key_values = ((key_states, value_states),)
+        return DynamicCache.from_legacy_cache(past_key_values)
 
     def _reorder_cache_from_llama_to_bloom(
-        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+        self, key_value: DynamicCache, batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
-        key_states, value_states = key_value
+        key_states, value_states = key_value.to_legacy_cache()[0]
         value_states = value_states.view(
             batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
@@ -178,7 +184,7 @@ def _reorder_cache_from_llama_to_bloom(
 
 
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
-@pytest.mark.forked
+# @pytest.mark.forked
 def test_optimized_block(device):
     if device == "cuda:0" and not torch.cuda.is_available():
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
@@ -195,7 +201,7 @@ def test_optimized_block(device):
     if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     elif config.model_type == "llama":
-        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+        unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
     else:
         pytest.skip(f"This test is not applicable to {config.model_type} models")

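The attention-mask change follows the same upgrade: the private LlamaModel._prepare_decoder_attention_mask was removed from transformers, so the test now calls the standalone _prepare_4d_causal_attention_mask helper instead. A rough sketch of what that call produces, assuming a transformers version in the range this PR pins (where the helper still exists); the shapes below are illustrative:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_len, hidden_dim = 1, 4, 0, 64
hidden_states = torch.zeros(batch_size, seq_length, hidden_dim)
attention_mask = torch.ones(batch_size, seq_length + past_len, dtype=torch.bool)

# Expands the 2D (batch, kv_len) padding mask into the additive 4D causal mask
# of shape (batch, 1, q_len, kv_len) that the decoder layer expects.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), hidden_states, past_len
)
print(causal_mask.shape)  # torch.Size([1, 1, 4, 4])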