Bump transformers and accelerate versions #554

Merged: 19 commits, Feb 15, 2024
remove falcon from ci tests
dvmazur committed Feb 11, 2024
commit 25ee8ecfd045e234d044e78ef5a76fa92173faea
2 changes: 0 additions & 2 deletions .github/workflows/run-tests.yaml
@@ -14,8 +14,6 @@ jobs:
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
fail-fast: false
6 changes: 3 additions & 3 deletions src/petals/client/remote_generation.py
@@ -6,20 +6,20 @@
import torch
import transformers
from hivemind.utils.logging import get_logger
from transformers.generation.utils import ModelOutput
from torch import Tensor
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import ModelOutput

from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from

from torch import Tensor

logger = get_logger(__name__)


class RemotePastKeyValues(Cache):
"""only keeps the number of seen tokens. pretends to be a legit cache"""

def __init__(self) -> None:
super().__init__()
self.seen_tokens = 0
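
Background for the RemotePastKeyValues placeholder above: with the bumped transformers version, the generation utilities expect past_key_values to implement the transformers.cache_utils.Cache interface instead of being a tuple of tensors. A minimal sketch of such a token-counting stub, assuming the Cache API from transformers >= 4.36; the class name and method bodies below are illustrative, not copied from this PR:

from typing import Optional, Tuple

import torch
from transformers.cache_utils import Cache


class SeenTokensOnlyCache(Cache):
    """Stores no tensors; only counts tokens already processed remotely."""

    def __init__(self) -> None:
        super().__init__()
        self.seen_tokens = 0

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # Report how many tokens the remote servers have already consumed.
        return self.seen_tokens

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Nothing is cached locally; the real KV cache lives on the servers.
        self.seen_tokens += key_states.shape[-2]
        return key_states, value_states
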
7 changes: 2 additions & 5 deletions src/petals/models/bloom/model.py
@@ -4,9 +4,9 @@
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
from transformers.cache_utils import Cache

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
@@ -124,7 +124,6 @@ def __init__(self, config: DistributedBloomConfig):
# Initialize weights and apply final processing
self.post_init()


def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
) -> dict:
@@ -174,11 +173,9 @@ def prepare_inputs_for_generation(
)
return model_inputs


def _temporary_reorder_cache(self, past_key_values, beam_idx):
return past_key_values



def get_output_embeddings(self):
return self.lm_head

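
The body of prepare_inputs_for_generation is collapsed in the hunks above; the usual pattern in this kind of override is to trim input_ids to the tokens the cache has not yet seen and to repackage the remaining generation kwargs. A standalone sketch of that pattern, assuming the transformers Cache API; the function name and exact argument handling are illustrative, not taken from the file:

from typing import Optional

import torch
from transformers.cache_utils import Cache


def prepare_inputs_sketch(
    input_ids: torch.Tensor,
    past_key_values: Optional[Cache] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    # Illustrative version of the usual prepare_inputs_for_generation logic.
    if past_key_values is not None:
        # During incremental decoding, feed only the not-yet-processed tokens
        # (typically just the newest one).
        past_length = past_key_values.get_seq_length()
        input_ids = input_ids[:, past_length:]
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
        **kwargs,
    }
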
4 changes: 2 additions & 2 deletions src/petals/models/llama/block.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@@ -19,7 +20,6 @@
repeat_kv,
rotate_half,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from petals.utils.cuda_graphs import make_inference_graphed_callable

@@ -99,7 +99,7 @@ def forward(
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
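
For reference, the hunk above concatenates the cached key/value tensors with the freshly computed ones along the sequence dimension and then calls repeat_kv so that grouped-query attention ends up with as many key/value heads as query heads. A small shape check of that step; all tensor sizes below are made up for illustration:

import torch
from transformers.models.llama.modeling_llama import repeat_kv

batch, n_kv_heads, head_dim = 1, 4, 64
num_key_value_groups = 8 // n_kv_heads  # e.g. 8 query heads sharing 4 KV heads

past_keys = torch.randn(batch, n_kv_heads, 10, head_dim)  # 10 cached tokens
new_keys = torch.randn(batch, n_kv_heads, 1, head_dim)    # 1 new token

key_states = torch.cat([past_keys, new_keys], dim=2)       # shape (1, 4, 11, 64)
key_states = repeat_kv(key_states, num_key_value_groups)   # shape (1, 8, 11, 64)
print(key_states.shape)
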
2 changes: 1 addition & 1 deletion src/petals/models/llama/model.py
@@ -101,7 +101,7 @@ def forward(
# Add last hidden state
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.view(output_shape)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
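
Background on why the return value above now carries a Cache-style past_key_values: newer transformers releases (4.36 and later) route caching through Cache objects such as DynamicCache, which remote_generation.py imports earlier in this diff, rather than nested tuples. A short illustrative sketch of that API, separate from the PR's own code:

import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast

cache = DynamicCache()
keys = torch.randn(1, 4, 3, 64)    # (batch, kv_heads, seq_len, head_dim)
values = torch.randn(1, 4, 3, 64)
cache.update(keys, values, layer_idx=0)  # store the KV pair for layer 0

output = BaseModelOutputWithPast(
    last_hidden_state=torch.randn(1, 3, 256),
    past_key_values=cache,
)
print(output.past_key_values.get_seq_length())  # 3 tokens cached so far
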