Bump transformers and accelerate versions (#554)
Bump versions for transformers and accelerate, remove falcon-rw-1b CI tests
dvmazur committed Feb 15, 2024
1 parent d59c15c commit 0d91bbd
Showing 11 changed files with 127 additions and 29 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/run-tests.yaml
@@ -14,8 +14,6 @@ jobs:
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
fail-fast: false
4 changes: 2 additions & 2 deletions setup.cfg
@@ -34,10 +34,10 @@ python_requires = >=3.8
install_requires =
torch>=1.12
bitsandbytes==0.41.1
accelerate>=0.22.0
accelerate>=0.27.2
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers>=4.32.0,<4.35.0 # if you change this, please also change version assert in petals/__init__.py
transformers==4.37.1 # if you change this, please also change version assert in petals/__init__.py
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
hivemind==1.1.10.post2
6 changes: 3 additions & 3 deletions src/petals/__init__.py
@@ -17,13 +17,13 @@
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs

__version__ = "2.3.0.dev1"
__version__ = "2.3.0.dev2"


if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert (
version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"


def _override_bfloat16_mode_default():
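Note: a minimal sketch of how the tightened import-time check behaves. The environment variable and the error message are taken from the hunk above; skipping the check is only sensible for local experiments.

```python
# Sketch: bypassing the version assert added above. PETALS_IGNORE_DEPENDENCY_VERSION
# is the variable checked in petals/__init__.py; it must be set before the first import.
import os

os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"
import petals  # the transformers>=4.37.1,<4.38.0 assert is skipped

# Without the variable, an out-of-range transformers install fails at import with:
# AssertionError: Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0
```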
1 change: 1 addition & 0 deletions src/petals/client/inference_session.py
@@ -211,6 +211,7 @@ def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
self._position = 0
self._max_length = max_length
self.output_ids = None
self.past_key_values = None

@property
def num_blocks(self) -> int:
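Note: a hedged sketch of the session-based generation flow this new attribute supports. The `inference_session()` / `session=` API is the existing Petals client interface; the model name is only an example, and running the snippet requires a swarm that actually serves it. How the slot is populated is handled in remote_generation.py below.

```python
# Hedged sketch: repeated generate() calls inside one inference session.
# The new `past_key_values` attribute presumably gives the session a place to
# carry cache bookkeeping between calls (see remote_generation.py in this diff).
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

MODEL_NAME = "bigscience/bloom-560m"  # example model from the CI matrix above
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoDistributedModelForCausalLM.from_pretrained(MODEL_NAME)

with model.inference_session(max_length=128) as session:
    prefix = tokenizer("A human talks to an assistant.", return_tensors="pt")["input_ids"]
    outputs = model.generate(prefix, max_new_tokens=16, session=session)

    follow_up = tokenizer(" Tell me more.", return_tensors="pt")["input_ids"]
    outputs = model.generate(follow_up, max_new_tokens=16, session=session)
```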
31 changes: 26 additions & 5 deletions src/petals/client/remote_generation.py
@@ -1,11 +1,13 @@
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import ContextManager, List, Optional
from typing import Any, ContextManager, Dict, List, Optional, Tuple

import torch
import transformers
from hivemind.utils.logging import get_logger
from torch import Tensor
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import ModelOutput

from petals.client.inference_session import InferenceSession
@@ -15,15 +17,29 @@
logger = get_logger(__name__)


@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
"""A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
class RemotePastKeyValues(Cache):
"""only keeps the number of seen tokens. pretends to be a legit cache"""

hypo_ids: Optional[torch.LongTensor] = None
def __init__(self) -> None:
super().__init__()
self.seen_tokens = 0
self.hypo_ids: Optional[torch.LongTensor] = None

def __getitem__(self, _index: int) -> List[torch.Tensor]:
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return self.seen_tokens

def get_max_length(self) -> Optional[int]:
return None

def update_seen(self, new_seen: int) -> None:
self.seen_tokens += new_seen

def reorder_cache(self, beam_idx):
pass


_skipped_tokens = ContextVar("skipped_tokens", default=0)

@@ -113,6 +129,11 @@ def generate(
# but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
_skipped_tokens.set(max(0, n_prev_tokens - 1))

if self._supports_cache_class and "past_key_values" not in kwargs:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(session.position)
kwargs["past_key_values"] = past_key_values

result = super().generate(inputs, *args, **kwargs)

sequences = result.sequences if isinstance(result, ModelOutput) else result
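Note: a small sketch of the bookkeeping the new cache class provides. `RemotePastKeyValues` is the class defined in the hunk above; the token counts are purely illustrative.

```python
# Illustrative only: RemotePastKeyValues tracks how many tokens the remote
# servers have already processed; it stores no real key/value tensors.
from petals.client.remote_generation import RemotePastKeyValues

past = RemotePastKeyValues()
past.update_seen(7)   # e.g., seeded from session.position in generate() above
past.update_seen(1)   # one more token per decoding step
assert past.get_seq_length() == 8
assert past.get_max_length() is None  # no hard limit is reported for the remote cache
```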
9 changes: 8 additions & 1 deletion src/petals/models/bloom/block.py
@@ -6,6 +6,7 @@
from typing import Optional, Tuple

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor


@@ -26,7 +27,13 @@ def forward(
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_length,
)
attention_mask = attention_mask.bool()
return super().forward(
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
)
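Note: a minimal sketch of what the replacement mask helper produces, assuming the `_prepare_4d_causal_attention_mask` behavior of the pinned transformers release; the shapes in the comments follow from that assumption.

```python
# Minimal sketch: the helper builds an additive 4-D causal mask; casting it to
# bool, as the hunk above does, marks masked positions as True.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_length, hidden_size = 2, 3, 4, 8
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, past_length + seq_length)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_length,
)
print(mask_4d.shape)         # expected: torch.Size([2, 1, 3, 7])
print(mask_4d.bool().any())  # True: future positions are masked for each query token
```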
60 changes: 59 additions & 1 deletion src/petals/models/bloom/model.py
@@ -4,6 +4,7 @@
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel

@@ -92,12 +93,16 @@ def forward(
if use_prompts:
hidden_states = hidden_states[:, self.pre_seq_len :]

if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))

# Add last hidden state
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=RemotePastKeyValues(),
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
_supports_cache_class = True

config_class = DistributedBloomConfig

@@ -118,6 +124,58 @@ def __init__(self, config: DistributedBloomConfig):
# Initialize weights and apply final processing
self.post_init()

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
) -> dict:
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]

if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs

def _temporary_reorder_cache(self, past_key_values, beam_idx):
return past_key_values

def get_output_embeddings(self):
return self.lm_head

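Note: the method above appears to mirror the stock transformers `prepare_inputs_for_generation`; below is a small, self-contained trace of its common branch once the cache has already seen `past_length` tokens (the numbers are illustrative).

```python
# Illustrative trace of the trimming above: only tokens the remote cache has
# not seen yet are forwarded on the next step.
import torch

past_length = 5
input_ids = torch.arange(7).unsqueeze(0)   # 7 tokens accumulated so far
trimmed = input_ids[:, past_length:]       # branch: past_length < input_ids.shape[1]
print(trimmed)                             # tensor([[5, 6]])

# position_ids are rebuilt from the attention mask and trimmed the same way
attention_mask = torch.ones(1, 7, dtype=torch.long)
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = position_ids[:, -trimmed.shape[1]:]
print(position_ids)                        # tensor([[5, 6]])
```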
12 changes: 8 additions & 4 deletions src/petals/models/llama/block.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@@ -84,8 +85,8 @@ def forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos = cos[:, :, kv_seq_len - q_len :]
sin = sin[:, :, kv_seq_len - q_len :]
cos = cos[kv_seq_len - q_len :]
sin = sin[kv_seq_len - q_len :]

if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -244,8 +245,11 @@ def forward(
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
)

outputs = super().forward(
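Note: a short sketch of the shape change the cos/sin re-indexing above adapts to. The assumption (not shown in this diff) is that the pinned transformers release returns rotary tables shaped `[seq_len, head_dim]` instead of the older `[1, 1, seq_len, head_dim]`.

```python
# Sketch of why the slicing changed from cos[:, :, kv_seq_len - q_len:]
# to cos[kv_seq_len - q_len:]: the two leading singleton dimensions are gone.
import torch

kv_seq_len, q_len, head_dim = 12, 1, 64
cos_old_style = torch.randn(1, 1, kv_seq_len, head_dim)  # older 4-D layout
cos_new_style = cos_old_style[0, 0]                      # assumed new 2-D layout

old_slice = cos_old_style[:, :, kv_seq_len - q_len:]     # [1, 1, q_len, head_dim]
new_slice = cos_new_style[kv_seq_len - q_len:]           # [q_len, head_dim]

assert torch.equal(old_slice[0, 0], new_slice)
```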
7 changes: 6 additions & 1 deletion src/petals/models/llama/model.py
@@ -90,16 +90,21 @@ def forward(
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)

if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))

# Remove prefix
if use_prompts:
hidden_states = hidden_states[:, self.pre_seq_len :]

# Add last hidden state
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.view(output_shape)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=RemotePastKeyValues(),
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
4 changes: 1 addition & 3 deletions src/petals/utils/peft.py
@@ -26,9 +26,7 @@


def check_peft_repository(repo_id: str) -> bool:
fs = HfFileSystem()
list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
return len(list_of_files) > 0
return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")


def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
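Note: a brief sketch of the simplified check. `HfFileSystem.exists` comes from huggingface_hub's fsspec interface; the repo id below is a hypothetical placeholder, and the filename is the assumed value of the `SAFETENSORS_WEIGHTS_NAME` constant imported at the top of peft.py.

```python
# Sketch: a single existence check replaces the glob-and-count above.
from huggingface_hub import HfFileSystem

SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"  # assumed value of the imported constant
repo_id = "some-user/some-lora-adapter"                 # hypothetical adapter repo

print(HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}"))
```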
20 changes: 13 additions & 7 deletions tests/test_optimized_layers.py
@@ -2,6 +2,8 @@

import pytest
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

@@ -116,6 +118,8 @@ def forward(
past_key_values_length = past_key_value[0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
elif use_cache:
past_key_value = DynamicCache()

if position_ids is None:
device = hidden_states.device
@@ -131,8 +135,9 @@ def forward(
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length

attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)

outputs = super().forward(
@@ -156,19 +161,20 @@

def _reorder_cache_from_bloom_to_llama(
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
) -> DynamicCache:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
key_states = key_states.view(
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
value_states = value_states.view(*key_states.shape)
return (key_states, value_states)
past_key_values = ((key_states, value_states),)
return DynamicCache.from_legacy_cache(past_key_values)

def _reorder_cache_from_llama_to_bloom(
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
self, key_value: DynamicCache, batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
key_states, value_states = key_value
key_states, value_states = key_value.to_legacy_cache()[0]
value_states = value_states.view(
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
@@ -195,7 +201,7 @@ def test_optimized_block(device):
if config.model_type == "falcon":
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
elif config.model_type == "llama":
unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
else:
pytest.skip(f"This test is not applicable to {config.model_type} models")
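Note: a short sketch of the legacy-format round trip the rewritten cache helpers rely on. `DynamicCache` is imported from transformers.cache_utils as in the test above; the tensor shapes are illustrative.

```python
# Sketch: converting one layer's (key, value) pair into a DynamicCache and back,
# as _reorder_cache_from_bloom_to_llama / _reorder_cache_from_llama_to_bloom now do.
import torch
from transformers.cache_utils import DynamicCache

key = torch.randn(2, 4, 6, 8)    # [batch, num_heads, seq_len, head_dim]
value = torch.randn(2, 4, 6, 8)

cache = DynamicCache.from_legacy_cache(((key, value),))
assert cache.get_seq_length() == 6

key_back, value_back = cache.to_legacy_cache()[0]
assert torch.equal(key_back, key) and torch.equal(value_back, value)
```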

