Switch adapters slightly faster #353

Merged: 9 commits, Jul 14, 2023
Changes from 1 commit
Commit message: DAED HELL SALAD
borzunov authored and Your Name committed Jul 14, 2023
commit e092ec264e26112f95594c5c600c3c9da22f7974
17 changes: 11 additions & 6 deletions src/petals/server/backend.py
@@ -17,17 +17,22 @@
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
-from petals.utils.peft import using_adapter
 
 logger = get_logger(__name__)
 
 
 class TransformerBackend(ModuleBackend):
     """A wrapper for a transformer block that can process requests for forward, backward and inference"""
 
+    _peft_module = None
+
     def __init__(
         self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
     ):
+        import petals.utils.peft as _peft_module
+
+        self._peft_module = _peft_module
+
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -83,12 +88,12 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_adapter(active_adapter):
+        with self._peft_module.using_adapter(active_adapter):
             return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_adapter(active_adapter):
+        with self._peft_module.using_adapter(active_adapter):
             return super().backward(*inputs)
 
     @torch.inference_mode()
@@ -99,9 +104,9 @@ def inference_step(
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors, using_adapter(
-            inference_info.active_adapter
-        ):
+        with self.memory_cache.use_cache(
+            *inference_info.cache_handles
+        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
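
The change itself is small: instead of importing using_adapter at module level, the backend now imports petals.utils.peft lazily inside __init__ and keeps the module object on the instance as self._peft_module, which the forward, backward, and inference paths then go through. Below is a minimal, self-contained sketch of that deferred-import pattern; the class name and the json stand-in module are illustrative only, not Petals code.

# Sketch of the deferred-import pattern used in the diff. "json" is just a
# cheap stand-in for petals.utils.peft; the point is when the import happens
# and where the module object is kept.

class BackendSketch:
    _dep_module = None  # class-level default, analogous to _peft_module = None

    def __init__(self):
        # Deferred import: runs only when an instance is constructed, so merely
        # importing the file that defines this class stays cheap.
        import json as _dep_module

        self._dep_module = _dep_module

    def encode(self, payload):
        # Later methods reach the dependency through the instance attribute,
        # mirroring self._peft_module.using_adapter(...) above.
        return self._dep_module.dumps(payload)


if __name__ == "__main__":
    print(BackendSketch().encode({"active_adapter": None}))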
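
All three call sites wrap their work in using_adapter(active_adapter), i.e. a context manager that selects which adapter is active for the duration of a block and restores the previous one afterwards, including when an exception is raised. The real petals.utils.peft.using_adapter is not shown in this diff, so the following is only an assumed, simplified illustration of that shape.

from contextlib import contextmanager

# Assumed, simplified stand-in for an adapter-selecting context manager; not
# the actual petals.utils.peft implementation.
_active_adapter = None  # module-level "currently selected adapter" marker


@contextmanager
def using_adapter(adapter_name):
    """Temporarily mark adapter_name as active; restore the old value on exit."""
    global _active_adapter
    previous, _active_adapter = _active_adapter, adapter_name
    try:
        yield
    finally:
        _active_adapter = previous


if __name__ == "__main__":
    with using_adapter("my-lora"):
        print(_active_adapter)   # my-lora
    print(_active_adapter)       # None (previous value restored)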