Split long sequences into chunks #403

Merged 8 commits on Jul 22, 2023
Changes from 1 commit
Simplify code
borzunov committed Jul 22, 2023
commit 124a3bb59fea09f6b7a186625269914166e203f2
src/petals/server/backend.py: 16 changes (6 additions, 10 deletions)
@@ -114,30 +114,26 @@ def inference_step(
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        seq_len = hidden_states.shape[1]
         max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
-        num_chunks = max(1, (hidden_states.shape[1] - 1) // max_chunk_length + 1)  # divide, round up

         with self.memory_cache.use_cache(
             *inference_info.cache_handles
         ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)

-            output_hidden_states = torch.empty_like(hidden_states) if num_chunks > 1 else None
-
+            output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-            for chunk_i in range(num_chunks):
-                hidden_states_chunk = hidden_states[:, chunk_i * max_chunk_length : (chunk_i + 1) * max_chunk_length, :]
+            for offset in range(0, seq_len, max_chunk_length):
+                hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
                 output_hidden_states_chunk, new_kvs = self.module.forward(
                     hidden_states_chunk, layer_past=layer_past, use_cache=True
                 )
-                if num_chunks > 1:
-                    output_hidden_states[
-                        :, chunk_i * max_chunk_length : (chunk_i + 1) * max_chunk_length, :
-                    ] = output_hidden_states_chunk
+                if seq_len > max_chunk_length:
+                    output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
                 else:
                     output_hidden_states = output_hidden_states_chunk  # saves one memcopy
                 layer_past = new_kvs
-            del layer_past

             self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
             return (output_hidden_states,)
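
For readers skimming the diff: below is a minimal, self-contained sketch of the offset-based chunking pattern this commit switches to. It is an illustration only, not Petals code; toy_layer and the tensor sizes are made up, and the real server also threads the attention KV cache (layer_past) through the loop. The point is that range(0, seq_len, max_chunk_length) plus plain slicing handles a ragged final chunk automatically, which is what lets the commit drop the num_chunks round-up arithmetic.

import torch

def toy_layer(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for an expensive transformer block
    return x * 2

def chunked_forward(hidden_states: torch.Tensor, max_chunk_length: int) -> torch.Tensor:
    seq_len = hidden_states.shape[1]
    # Allocate a separate output buffer only when there is more than one chunk,
    # mirroring the `seq_len > max_chunk_length` check in the diff
    output = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
    for offset in range(0, seq_len, max_chunk_length):
        chunk = hidden_states[:, offset : offset + max_chunk_length, :]  # last chunk may be shorter
        chunk_out = toy_layer(chunk)
        if seq_len > max_chunk_length:
            output[:, offset : offset + max_chunk_length] = chunk_out
        else:
            output = chunk_out  # single chunk: reuse the result, saving one memcopy
    return output

x = torch.randn(2, 10, 4)  # [batch_size, seq_len, hid_size]
assert torch.equal(chunked_forward(x, max_chunk_length=3), toy_layer(x))  # chunks of 3, 3, 3, 1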