Force use_cache=True (bigscience-workshop#496)
borzunov authored and Dobromir Popov committed Sep 5, 2023
1 parent 29f7249 commit 1606bbe
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -40,7 +40,7 @@ install_requires =
     transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
-    hivemind @ git+https://github.com/learning-at-home/hivemind
+    hivemind==1.1.10.post2
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
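The dependency change above swaps a moving git HEAD for a fixed release. A minimal sketch for confirming which release actually got installed (standard-library API; the version string is taken from the diff above):

    # Verify that the installed hivemind matches the pin in setup.cfg.
    from importlib.metadata import version

    assert version("hivemind") == "1.1.10.post2", version("hivemind")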
3 changes: 1 addition & 2 deletions src/petals/models/bloom/model.py
@@ -43,7 +43,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
+        use_cache: Optional[bool] = None,  # Not used here but needed for HF Transformers compatibility
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -63,7 +63,6 @@ def forward(
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
         assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
-        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"
3 changes: 1 addition & 2 deletions src/petals/models/llama/model.py
@@ -43,7 +43,7 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RemotePastKeyValues] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
+        use_cache: Optional[bool] = None,  # Not used here but needed for HF Transformers compatibility
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -65,7 +65,6 @@ def forward(
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"
-        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"
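Both model files receive the same edit, so the caller-visible effect can be summarized once. A minimal standalone sketch of the behavior change (hypothetical stubs, not Petals code): before this commit, forward() rejected use_cache=False with an assertion; after it, the argument is kept in the signature for HF Transformers compatibility but ignored, so the remote attention cache is effectively always on, hence the commit title.

    from typing import Optional

    def forward_before(use_cache: Optional[bool] = None) -> str:
        # Old behavior: use_cache=False was explicitly rejected.
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        return "ran with cache"

    def forward_after(use_cache: Optional[bool] = None) -> str:
        # New behavior: the argument is accepted but not used; caching is forced on.
        return "ran with cache"

    forward_after(use_cache=False)    # fine: the argument is ignored
    # forward_before(use_cache=False) # would raise AssertionError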
