Fix p2p pushing in rpc_inference, support transformers 4.38.2 #563

Merged: 4 commits on Mar 17, 2024
Changes from 1 commit
support transformers 4.38.2
justheuristic committed Mar 17, 2024
commit 627839522749f24381d15531738b20114cb66dbf
2 changes: 1 addition & 1 deletion src/petals/__init__.py
@@ -22,7 +22,7 @@

 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.39.0")
+        version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
     ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.39.0"


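For readers who want to check their environment against the new lower bound before running Petals, a minimal standalone sketch of the same comparison follows. It reuses packaging.version as the diff does; the PETALS_IGNORE_DEPENDENCY_VERSION escape hatch is the one already visible in the context above, and the exact error wording here is illustrative.

```python
import os

import transformers
from packaging import version

# Same half-open range as the updated assert: 4.38.2 <= installed < 4.39.0.
installed = version.parse(transformers.__version__)
supported = version.parse("4.38.2") <= installed < version.parse("4.39.0")

if not supported and not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    raise RuntimeError(
        f"transformers {transformers.__version__} is outside the supported range; "
        "install a version in [4.38.2, 4.39.0) or set PETALS_IGNORE_DEPENDENCY_VERSION to skip this check"
    )
```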
17 changes: 13 additions & 4 deletions src/petals/models/llama/block.py
@@ -50,9 +50,15 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)

         bsz, q_len, _ = hidden_states.size()

         if self.config.pretraining_tp > 1:
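The new fallback above reconstructs position_ids from how many tokens are already in the key/value cache. A toy illustration of that arithmetic is below; the concrete sizes are made up, only the indexing mirrors the diff.

```python
import torch

# Suppose 5 tokens are already cached and 3 new tokens arrive in this forward pass.
past_seen_tokens = 5                     # plays the role of past_key_value[0].shape[2]
hidden_states = torch.zeros(1, 3, 4096)  # (batch, q_len, hidden_size), toy values

position_ids = torch.arange(
    past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
).unsqueeze(0)

print(position_ids)  # tensor([[5, 6, 7]]) -- the new tokens continue right after the cache
```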
@@ -84,9 +90,8 @@ def forward(
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
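The replacement lines call the rotary embedding with position_ids (the newer transformers calling convention) and then add a head axis by hand. The shape-only sketch below suggests why that unsqueeze(1) helps when multiplying against (batch, heads, seq, head_dim) query/key tensors; the sizes are illustrative, the cos/sin shape is an assumption rather than the library's documented contract, and the elementwise product stands in for the full rotate-half formula.

```python
import torch

bsz, num_heads, q_len, head_dim = 2, 32, 3, 128

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
cos = torch.randn(bsz, q_len, head_dim)  # assumed shape of the per-position cos table
sin = torch.randn(bsz, q_len, head_dim)

# Without an explicit head axis, broadcasting would try to line the batch axis
# up against num_heads, so insert that axis first, as the diff does.
cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # -> (bsz, 1, q_len, head_dim)

rotated = query_states * cos + query_states * sin  # broadcasts cleanly over all 32 heads
print(rotated.shape)  # torch.Size([2, 32, 3, 128])
```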
@@ -160,6 +165,8 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )

         hidden_states = residual + hidden_states
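Taken together, the last two hunks make the decoder layer accept the arguments newer transformers releases may pass (cache_position plus a **kwargs catch-all) and forward them to the attention module unchanged. A stripped-down sketch of that pass-through pattern is below; the class and attribute names are placeholders, not Petals code.

```python
from typing import Optional

import torch


class PassthroughDecoderLayer(torch.nn.Module):
    """Toy wrapper that tolerates keyword arguments added by newer transformers."""

    def __init__(self, attn: torch.nn.Module):
        super().__init__()
        self.attn = attn

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # Unknown keywords are not rejected here; they are simply handed down
        # to the wrapped attention module.
        return self.attn(hidden_states, cache_position=cache_position, **kwargs)
```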
2 changes: 2 additions & 0 deletions src/petals/models/llama/model.py
@@ -47,6 +47,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -62,6 +63,7 @@
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert cache_position is None, "cache_position is only supported for dedicated inference"
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"