Bump transformers and accelerate versions #554

Merged
merged 19 commits on Feb 15, 2024
bloom
dvmazur committed Feb 11, 2024
commit 72e1f0059910ff3de7ffd00a2304199129d98b48
4 changes: 4 additions & 0 deletions src/petals/client/remote_generation.py
@@ -1,4 +1,4 @@
import contextlib
Check failure on line 1 (GitHub Actions / isort): Imports are incorrectly sorted and/or formatted.
import dataclasses
from contextvars import ContextVar
from typing import Any, ContextManager, Dict, List, Optional, Tuple
@@ -25,6 +25,9 @@
        self.seen_tokens = 0
        self.hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self.seen_tokens

@@ -130,6 +133,7 @@
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(session.position)
kwargs["past_key_values"] = past_key_values

result = super().generate(inputs, *args, **kwargs)

sequences = result.sequences if isinstance(result, ModelOutput) else result
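For context on the change above: `RemotePastKeyValues` only needs to track how many tokens the remote servers have already processed, while still looking enough like a Hugging Face cache for stock generation utilities that index into it. Below is a minimal, self-contained sketch of that interface — `SketchPastKeyValues` and the placeholder `DUMMY` tensor are illustrative stand-ins, not code from this PR.

```python
from typing import List, Optional

import torch

DUMMY = torch.empty(0)  # placeholder tensor standing in for petals' DUMMY


class SketchPastKeyValues:
    """Toy cache: stores only a token counter, no real key/value tensors."""

    def __init__(self) -> None:
        self.seen_tokens = 0

    def update_seen(self, new_seen: int) -> None:
        self.seen_tokens += new_seen  # record tokens already processed remotely

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self.seen_tokens

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        # Legacy code paths may index past_key_values[0][0]; return a harmless placeholder.
        return [DUMMY]


cache = SketchPastKeyValues()
cache.update_seen(5)                 # e.g. a 5-token prompt was processed
assert cache.get_seq_length() == 5   # generation utilities read the cached length here
assert cache[0][0].numel() == 0      # indexing works, but yields only the placeholder
```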
9 changes: 8 additions & 1 deletion src/petals/models/bloom/block.py
@@ -6,6 +6,7 @@
from typing import Optional, Tuple

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor


@@ -26,7 +27,13 @@ def forward(
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_length),
            inputs_embeds=hidden_states,
            past_key_values_length=past_length,
        )
        attention_mask = attention_mask.bool()
        return super().forward(
            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )
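A hedged sketch of what the new mask path produces, assuming a transformers release that ships `_prepare_4d_causal_attention_mask` in `transformers.modeling_attn_mask_utils` (as the import above expects): the helper returns a `(batch, 1, query_len, key_len)` float mask with `0.0` at attended positions and the dtype minimum at masked ones, so the trailing `.bool()` marks exactly the positions the Bloom block must not attend to.

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_length, hidden_size = 2, 3, 4, 8
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)   # stands in for inputs_embeds
attention_mask = torch.ones(batch_size, seq_length + past_length)  # no padding

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_length,
)
print(mask_4d.shape)        # torch.Size([2, 1, 3, 7]): queries are the 3 new tokens, keys include the past
bool_mask = mask_4d.bool()  # True = blocked (the float mask holds dtype-min there), False = attended
print(bool_mask[0, 0])      # only future positions among the 3 new tokens are True
```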
63 changes: 62 additions & 1 deletion src/petals/models/bloom/model.py
@@ -1,4 +1,4 @@
from typing import Optional
Check failure on line 1 (GitHub Actions / isort): Imports are incorrectly sorted and/or formatted.

import hivemind
import torch
@@ -6,6 +6,7 @@
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
from transformers.cache_utils import Cache

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
@@ -92,12 +93,16 @@
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        if past_key_values is None:
            past_key_values = RemotePastKeyValues()
        past_key_values.update_seen(hidden_states.size(1))

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=RemotePastKeyValues(),
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
@@ -107,6 +112,7 @@
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
_supports_cache_class = True

config_class = DistributedBloomConfig

@@ -118,6 +124,61 @@
        # Initialize weights and apply final processing
        self.post_init()

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ) -> dict:
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _temporary_reorder_cache(self, past_key_values, beam_idx):
        return past_key_values

    def get_output_embeddings(self):
        return self.lm_head
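The `prepare_inputs_for_generation` override above follows the cache-aware recipe used elsewhere in transformers: once the cache reports that `past_length` tokens are already covered, only the not-yet-processed suffix of `input_ids` is forwarded. A rough, self-contained illustration of that branch (the `FakeCache` class and the numbers are hypothetical, not petals code):

```python
import torch


class FakeCache:
    """Stands in for RemotePastKeyValues: only the bookkeeping used by the branch above."""

    def __init__(self, seen_tokens: int) -> None:
        self.seen_tokens = seen_tokens

    def get_seq_length(self, layer_idx: int = 0) -> int:
        return self.seen_tokens

    def get_max_length(self):
        return None  # no fixed cache capacity


past = FakeCache(seen_tokens=6)                      # 6 tokens already processed remotely
input_ids = torch.arange(8).unsqueeze(0)             # full sequence: 6 cached + 2 new tokens
attention_mask = torch.ones(1, 8, dtype=torch.long)

past_length = past.seen_tokens
if attention_mask.shape[1] > input_ids.shape[1]:
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
elif past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]

print(input_ids)  # tensor([[6, 7]]) -- only the 2 new tokens would be sent to the swarm
```

`_temporary_reorder_cache` returning its input unchanged fits the same picture: presumably there are no per-beam key/value tensors held on the client, so there is nothing to reorder locally.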