Forward backward #559

Closed
wants to merge 21 commits into from
Changes from 1 commit
first attempt
dvmazur committed Feb 1, 2024
commit 321b0483a1192e05b6918a3e13a0efddbeb0a115
6 changes: 3 additions & 3 deletions setup.cfg
@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.22.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.37.1  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2
@@ -47,7 +47,7 @@ install_requires =
     cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft==0.5.0
+    peft==0.7.1
     safetensors>=0.3.1
     Dijkstar>=2.6.0

@@ -61,4 +61,4 @@ dev =
    psutil

[options.packages.find]
-where = src
+where = src
7 changes: 4 additions & 3 deletions src/petals/__init__.py
@@ -21,9 +21,10 @@


if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
-    assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+    # assert (
+    #     version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
+    # ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+    ...


def _override_bfloat16_mode_default():
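Note on the version gate above: the check only runs when PETALS_IGNORE_DEPENDENCY_VERSION is unset, and with the assert commented out any installed transformers version is now accepted. A minimal sketch of how the same gate could instead be pointed at the new transformers==4.37.1 pin (an illustration for context, not part of this commit):

import os

import transformers
from packaging import version

if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    # Assumed pin, mirroring setup.cfg in this commit; set the env var above to skip the check.
    assert version.parse(transformers.__version__) == version.parse("4.37.1"), (
        "Please install the pinned transformers version: pip install transformers==4.37.1"
    )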
1 change: 1 addition & 0 deletions src/petals/client/inference_session.py
@@ -211,6 +211,7 @@ def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
        self._position = 0
        self._max_length = max_length
        self.output_ids = None
+        self.past_key_values = None

    @property
    def num_blocks(self) -> int:
34 changes: 27 additions & 7 deletions src/petals/client/remote_generation.py
@@ -1,28 +1,44 @@
import contextlib
import dataclasses
from contextvars import ContextVar
-from typing import ContextManager, List, Optional
+from typing import Any, ContextManager, Dict, List, Optional, Tuple

import torch
import transformers
from hivemind.utils.logging import get_logger
from transformers.generation.utils import ModelOutput
+from transformers.cache_utils import Cache, DynamicCache

from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from

+from torch import Tensor
+
logger = get_logger(__name__)


-@dataclasses.dataclass(frozen=True)
-class RemotePastKeyValues:
-    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
-    hypo_ids: Optional[torch.LongTensor] = None
-    def __getitem__(self, _index: int) -> List[torch.Tensor]:
-        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
+class RemotePastKeyValues(Cache):
+    """Only keeps the number of seen tokens; pretends to be a legitimate cache."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seen_tokens = 0
+        self.hypo_ids: Optional[torch.LongTensor] = None
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        return self.seen_tokens
+
+    def get_max_length(self) -> Optional[int]:
+        return None
+
+    def update_seen(self, new_seen: int) -> None:
+        self.seen_tokens += new_seen
+
+    def reorder_cache(self, beam_idx):
+        pass
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.seen_tokens=})"


_skipped_tokens = ContextVar("skipped_tokens", default=0)
@@ -113,6 +129,10 @@ def generate(
            # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
            _skipped_tokens.set(max(0, n_prev_tokens - 1))

+            if "past_key_values" not in kwargs:
+                rpkv = RemotePastKeyValues()
+                rpkv.update_seen(session.position)
+                kwargs["past_key_values"] = rpkv
            result = super().generate(inputs, *args, **kwargs)

            sequences = result.sequences if isinstance(result, ModelOutput) else result
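Taken together, the new RemotePastKeyValues class and the seeding code in generate() above replace the tensor-holding cache with a plain counter that starts from the inference session's position. A self-contained sketch of that behaviour (the stand-in class below is illustrative only and deliberately avoids the petals and transformers imports):

from typing import Optional


class SeenTokenCache:
    """Stand-in for RemotePastKeyValues: counts seen tokens, stores no tensors."""

    def __init__(self) -> None:
        self.seen_tokens = 0

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # the generation loop uses this to size masks and position ids
        return self.seen_tokens

    def get_max_length(self) -> Optional[int]:
        return None  # no client-side limit on cache length

    def update_seen(self, new_seen: int) -> None:
        self.seen_tokens += new_seen


# Seeding before generation, mirroring the new branch in generate():
session_position = 7  # hypothetical: tokens already processed in the inference session
cache = SeenTokenCache()
cache.update_seen(session_position)
assert cache.get_seq_length() == 7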
14 changes: 9 additions & 5 deletions src/petals/models/llama/block.py
@@ -19,6 +19,7 @@
    repeat_kv,
    rotate_half,
)
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from petals.utils.cuda_graphs import make_inference_graphed_callable

@@ -84,8 +85,8 @@ def forward(
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]

        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -98,7 +99,7 @@
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -244,8 +245,11 @@ def forward(
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
        )

        outputs = super().forward(
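The last hunk above swaps the removed LlamaModel._prepare_decoder_attention_mask call for transformers' _prepare_4d_causal_attention_mask helper. For intuition only, here is a hand-rolled sketch (an approximation, not the transformers implementation) of the kind of additive mask such a helper returns for these shapes: [batch_size, 1, seq_length, past_key_values_length + seq_length], zero where attention is allowed and a large negative value elsewhere.

import torch


def sketch_causal_mask(
    batch_size: int, seq_length: int, past_key_values_length: int, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    total_len = past_key_values_length + seq_length
    # absolute positions of the current query tokens and of all key positions
    query_pos = torch.arange(past_key_values_length, total_len).view(seq_length, 1)
    key_pos = torch.arange(total_len).view(1, total_len)
    allowed = key_pos <= query_pos  # causal: attend only to past and current positions
    mask = torch.full((seq_length, total_len), torch.finfo(dtype).min, dtype=dtype)
    mask = mask.masked_fill(allowed, 0.0)
    return mask.view(1, 1, seq_length, total_len).expand(batch_size, 1, seq_length, total_len)


print(sketch_causal_mask(batch_size=2, seq_length=3, past_key_values_length=4).shape)
# torch.Size([2, 1, 3, 7])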
6 changes: 5 additions & 1 deletion src/petals/models/llama/model.py
@@ -90,16 +90,20 @@ def forward(
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

+        past_key_values = past_key_values if past_key_values is not None else RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
        # Remove prefix
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
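With this change, forward() reuses the caller's past_key_values (creating one when absent) and advances the seen-token count by the number of tokens in the current call, so the count accumulates across a prefill and the subsequent single-token decoding steps. A toy illustration of that bookkeeping with hypothetical step sizes (stand-in class, not petals code):

class SeenTokenCounter:
    """Minimal stand-in for the RemotePastKeyValues.update_seen() bookkeeping."""

    def __init__(self) -> None:
        self.seen_tokens = 0

    def update_seen(self, new_seen: int) -> None:
        self.seen_tokens += new_seen


past_key_values = None
for num_new_tokens in (8, 1, 1):  # hypothetical: an 8-token prefill, then two decode steps
    past_key_values = past_key_values if past_key_values is not None else SeenTokenCounter()
    past_key_values.update_seen(num_new_tokens)  # mirrors update_seen(hidden_states.size(1))
print(past_key_values.seen_tokens)  # 10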