Commit
This PR adds:
- Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B.
- CI tests for Falcon-RW-1B.
- A `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab); see the example command after this list.

Limitations:
- Backward pass support is broken for now; it will be fixed in #500.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
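As an illustration of the new flag, a server operator could run something like `python -m petals.cli.run_server tiiuae/falcon-40b --throughput dry_run` to measure throughput and exit immediately (the entry point and model name here are assumptions based on the existing Petals CLI, not part of this diff).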
Showing 9 changed files with 356 additions and 15 deletions.
petals/models/__init__.py
@@ -1,2 +1,3 @@
from petals.models.bloom import *
from petals.models.falcon import *
from petals.models.llama import *
petals/models/falcon/__init__.py
@@ -0,0 +1,15 @@
from petals.models.falcon.block import WrappedFalconBlock
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.model import (
    DistributedFalconForCausalLM,
    DistributedFalconForSequenceClassification,
    DistributedFalconModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedFalconConfig,
    model=DistributedFalconModel,
    model_for_causal_lm=DistributedFalconForCausalLM,
    model_for_sequence_classification=DistributedFalconForSequenceClassification,
)
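Registering these classes lets the generic auto wrappers in `petals.utils.auto_config` resolve Falcon checkpoints without Falcon-specific imports on the caller's side. A minimal sketch (hedged: it assumes the `AutoDistributedConfig` and `AutoDistributedModelForCausalLM` wrappers exported by Petals, and uses `tiiuae/falcon-rw-1b` only as an example checkpoint):

from petals import AutoDistributedConfig, AutoDistributedModelForCausalLM

# Dispatch is based on the checkpoint's model type, matching the classes registered above
config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-rw-1b")
assert type(config).__name__ == "DistributedFalconConfig"

# The causal-LM wrapper resolves the same way (left commented out: it would try to join the swarm)
# model = AutoDistributedModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b")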
petals/models/falcon/block.py
@@ -0,0 +1,94 @@
""" | ||
Falcon intermediate layer | ||
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py | ||
See commit history for authorship. | ||
""" | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor | ||
|
||
KVCache = Tuple[torch.Tensor, torch.Tensor] | ||
|
||
|
||
class WrappedFalconBlock(FalconDecoderLayer): | ||
def forward( | ||
self, | ||
hidden_states: torch.Tensor, | ||
*args, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
alibi: Optional[torch.Tensor] = None, | ||
layer_past: Optional[KVCache] = None, | ||
use_cache: bool = False, | ||
**kwargs | ||
): | ||
batch_size, seq_length = hidden_states.shape[:2] | ||
|
||
if layer_past is not None: | ||
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) | ||
past_length = 0 if layer_past is None else layer_past[0].shape[1] | ||
seq_length_with_past = seq_length + past_length | ||
|
||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) | ||
if alibi is None and self.config.alibi: | ||
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) | ||
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) | ||
|
||
outputs = super().forward( | ||
hidden_states, | ||
*args, | ||
attention_mask=attention_mask, | ||
alibi=alibi, | ||
layer_past=layer_past, | ||
use_cache=use_cache, | ||
**kwargs | ||
) | ||
|
||
if use_cache: | ||
present_key_value = outputs[-1] | ||
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) | ||
outputs = outputs[:-1] + (present_key_value,) | ||
|
||
return outputs | ||
|
||
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: | ||
key_states, value_states = key_value | ||
|
||
key_states = key_states.permute(0, 2, 1) | ||
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] | ||
|
||
if self.config.new_decoder_architecture: | ||
key_states = self._expand_states(key_states) | ||
value_states = self._expand_states(value_states) | ||
|
||
return (key_states, value_states) | ||
|
||
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: | ||
key_states, value_states = key_value | ||
|
||
if self.config.new_decoder_architecture: | ||
key_states = self._collapse_states(key_states) | ||
value_states = self._collapse_states(value_states) | ||
|
||
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] | ||
key_states = key_states.permute(0, 2, 1) | ||
|
||
return (key_states, value_states) | ||
|
||
def _expand_states(self, state: torch.Tensor) -> torch.Tensor: | ||
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape | ||
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads | ||
|
||
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) | ||
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy | ||
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy | ||
return state | ||
|
||
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: | ||
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape | ||
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads | ||
|
||
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) | ||
state = state[:, :, 0] | ||
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) | ||
return state |
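The `_reorder_cache_*` helpers above convert between the Bloom-style cache layout used elsewhere in Petals (keys stored as `[batch_size * num_kv_heads, head_dim, seq_len]`, as implied by the permute and assert) and Falcon's layout (keys as `[batch_size * num_kv_heads, seq_len, head_dim]`, expanded to every attention head when the new decoder architecture is used). A standalone sketch of the shape round-trip (sizes are hypothetical, chosen to resemble a grouped-KV Falcon model; plain tensor ops only, no Petals imports):

import torch

# Hypothetical sizes for a grouped-KV ("new decoder architecture") Falcon model
batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim = 2, 8, 16, 5, 64
num_attention_heads = num_kv_heads * num_key_value_groups  # 128

# Bloom-style key cache: [batch_size * num_kv_heads, head_dim, seq_len]
bloom_key = torch.randn(batch_size * num_kv_heads, head_dim, seq_len)

# Bloom -> Falcon: transpose to [..., seq_len, head_dim], then broadcast KV heads to all attention heads
falcon_key = bloom_key.permute(0, 2, 1)
falcon_key = falcon_key.reshape(batch_size, num_kv_heads, 1, seq_len, head_dim)
falcon_key = falcon_key.expand(-1, -1, num_key_value_groups, -1, -1)  # No copy yet
falcon_key = falcon_key.reshape(batch_size * num_attention_heads, seq_len, head_dim)  # Copy happens here

# Falcon -> Bloom: keep one copy per KV head and transpose back
restored = falcon_key.reshape(batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim)
restored = restored[:, :, 0].reshape(batch_size * num_kv_heads, seq_len, head_dim)
restored = restored.permute(0, 2, 1)

assert torch.equal(restored, bloom_key)  # The round-trip is lossless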
petals/models/falcon/config.py
@@ -0,0 +1,45 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.falcon.block import WrappedFalconBlock
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedFalconBlock
    attn_class = FalconAttention
    block_prefix = "transformer.h"

    @property
    def num_key_value_groups(self) -> int:
        if self.new_decoder_architecture:
            return self.num_attention_heads // self.num_kv_heads
        if self.multi_query:
            return self.num_attention_heads
        return 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")

        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
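Two details above are worth spelling out. First, `num_key_value_groups` covers the three Falcon attention layouts: with the new decoder architecture it is `num_attention_heads // num_kv_heads`, with multi-query attention it equals `num_attention_heads` (a single shared KV head), and otherwise it is 1. Second, `from_pretrained` normalizes the repo name into a DHT prefix so that blocks of the same model hosted under different accounts are merged. A small sketch of that normalization (the standalone helper and the second repo name are hypothetical, written only to mirror the logic above):

def derive_dht_prefix(model_name_or_path: str) -> str:
    prefix = str(model_name_or_path).split("/")[-1]  # Keep only the repo name, dropping the account
    return prefix.replace(".", "-")  # Mirror the dot replacement done in from_pretrained

assert derive_dht_prefix("tiiuae/falcon-40b") == "falcon-40b"
assert derive_dht_prefix("my-org/falcon-custom-v1.1") == "falcon-custom-v1-1"  # Hypothetical repo name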
petals/models/falcon/model.py
@@ -0,0 +1,149 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon import (
    FalconForCausalLM,
    FalconForSequenceClassification,
    FalconModel,
    FalconPreTrainedModel,
)

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.falcon.config import DistributedFalconConfig
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
    """FalconModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        self.h = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The causal mask will be added on the server-side
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"
        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
        assert return_dict is None or return_dict, f"{return_dict=} is not supported"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        else:
            prompts = intermediate_prompts = None

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        hidden_states = self.h(
            hidden_states,
            prompts=intermediate_prompts,
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=RemotePastKeyValues(),
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()


class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.transformer = DistributedFalconModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head


class DistributedFalconForSequenceClassification(
    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedFalconModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
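A client-side usage sketch for the classes above (hedged: it assumes a swarm hosting the checkpoint is reachable, and Falcon-RW-1B is only the model family exercised by the new CI tests, so the exact repo name is an assumption):

from transformers import AutoTokenizer
from petals.models.falcon import DistributedFalconForCausalLM

model_name = "tiiuae/falcon-rw-1b"  # Example checkpoint, assumed to be served by the swarm
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedFalconForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Petals runs Falcon layers", return_tensors="pt")["input_ids"]
# RemoteGenerationMixin routes the forward passes through servers hosting the transformer.h blocks
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))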