Switch adapters slightly faster #353

Merged: 9 commits, Jul 14, 2023
35 changes: 13 additions & 22 deletions src/petals/server/backend.py
@@ -24,9 +24,15 @@
 class TransformerBackend(ModuleBackend):
     """A wrapper for a transformer block that can process requests for forward, backward and inference"""
 
+    _peft_module = None
+
     def __init__(
         self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
     ):
+        import petals.utils.peft as _peft_module
+
+        self._peft_module = _peft_module
+
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -82,13 +88,13 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().forward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().backward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().backward(*inputs)
 
     @torch.inference_mode()
     def inference_step(
@@ -98,8 +104,9 @@ def inference_step(
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        self.load_adapter_(inference_info.active_adapter)
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+        with self.memory_cache.use_cache(
+            *inference_info.cache_handles
+        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
@@ -150,22 +157,6 @@ def shutdown(self):
         for p in self.module.parameters():
             p.data = dummy
 
-    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
-        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
-
-        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
-        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
-
-        loaded = False
-        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
-                layer.active_adapter = active_adapter  # empty string for no adapter
-                if active_adapter in layer.lora_A.keys():
-                    loaded = True
-
-        if active_adapter and not loaded:
-            raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded")
-
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
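The backend.py diff above is the core of the speedup: the deleted load_adapter_ helper walked every module in the block and mutated each LoRA layer's active_adapter on every forward, backward, and inference call, whereas the new code wraps each call in petals.utils.peft.using_adapter, which only swaps one shared context variable. Below is a minimal standalone sketch of that pattern, assuming nothing beyond the standard library; AdapterContext and ToyLoraLayer are invented names for illustration, not Petals classes.

# Sketch: select an adapter via a context manager instead of rewriting per-layer state.
import contextlib
from typing import Optional


class AdapterContext:
    _active: Optional[str] = None  # shared by all layers, so switching is O(1), not O(num layers)

    @staticmethod
    @contextlib.contextmanager
    def using_adapter(name: Optional[str]):
        prev, AdapterContext._active = AdapterContext._active, name
        try:
            yield
        finally:
            AdapterContext._active = prev  # restored even if the wrapped call raises


class ToyLoraLayer(AdapterContext):
    @property
    def active_adapter(self) -> Optional[str]:
        return AdapterContext._active  # read from the shared context, not per-layer state

    def forward(self, x: float) -> float:
        # pretend the adapter doubles the output; real layers would add their LoRA deltas
        return x * (2.0 if self.active_adapter == "my-adapter" else 1.0)


layer = ToyLoraLayer()
with AdapterContext.using_adapter("my-adapter"):
    assert layer.forward(3.0) == 6.0  # adapter active inside the context
assert layer.forward(3.0) == 3.0      # and automatically deselected outside it

Because the previous value is restored in a finally block, a failed or nested request cannot leave a stale adapter selected for the next call.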
18 changes: 13 additions & 5 deletions src/petals/server/handler.py
@@ -68,6 +68,7 @@ def __init__(
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        adapters: Optional[Sequence[str]],
         dht_prefix: str,
         push_manager: multiprocessing.managers.SyncManager,
         session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
@@ -81,6 +82,7 @@ def __init__(
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
+        self.adapters = adapters
         self._push_manager = push_manager
         self._session_queues = session_queues
         self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
@@ -141,7 +143,7 @@ async def rpc_inference(
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")
 
@@ -355,7 +357,7 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PCont
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -382,7 +384,7 @@ async def rpc_forward_stream(
             self._log_request("rpc_forward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
Expand Down Expand Up @@ -433,7 +435,7 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon

requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = metadata.get("active_adapter", "")
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
@@ -458,7 +460,7 @@ async def rpc_backward_stream(
             self._log_request("rpc_backward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -476,6 +478,12 @@
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     yield runtime_pb2.ExpertResponse(tensors=[part])
 
+    def _get_active_adapter(self, metadata: dict) -> str:
+        active_adapter = metadata.get("active_adapter", "")
+        if active_adapter and (active_adapter not in self.adapters):
+            raise KeyError(f"adapter {active_adapter} not found")
+        return active_adapter
+
     def _serialize_grads(
         self,
         grads: Sequence[torch.Tensor],
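With the handler.py changes above, every RPC resolves the adapter name through _get_active_adapter, so a request for an adapter the server does not serve is rejected up front rather than failing deep inside the backend. The following standalone sketch mirrors that check; the function name get_active_adapter, the served_adapters argument, and the adapter names are hypothetical and only illustrate the behaviour.

# Sketch: validate the requested adapter against the list this server was started with.
from typing import Optional, Sequence


def get_active_adapter(metadata: dict, served_adapters: Optional[Sequence[str]]) -> str:
    """Return the requested adapter name ("" means the bare model); reject unknown adapters."""
    active_adapter = metadata.get("active_adapter", "")
    if active_adapter and (not served_adapters or active_adapter not in served_adapters):
        raise KeyError(f"adapter {active_adapter} not found")
    return active_adapter


served = ["example/lora-adapter"]            # hypothetical adapter repo name
assert get_active_adapter({}, served) == ""  # no adapter requested: the bare model is always allowed
assert get_active_adapter({"active_adapter": "example/lora-adapter"}, served) == "example/lora-adapter"
try:
    get_active_adapter({"active_adapter": "unknown/adapter"}, served)
except KeyError:
    pass  # expected: the handler reports this to the client before any compute is scheduled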
1 change: 1 addition & 0 deletions src/petals/server/server.py
@@ -534,6 +534,7 @@ def __init__(
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                adapters=adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,
55 changes: 48 additions & 7 deletions src/petals/utils/peft.py
@@ -1,3 +1,4 @@
+import contextlib
 import re
 import time
 from typing import Optional, Sequence
@@ -118,6 +119,47 @@ def load_peft(
             time.sleep(delay)
 
 
+class AdapterContextMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
+
+    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
+    _context_active_adapter = ADAPTER_NOT_SET
+
+    @staticmethod
+    @contextlib.contextmanager
+    def using_adapter(active_adapter: Optional[str]):
+        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
+        try:
+            yield
+        finally:
+            AdapterContextMixin._context_active_adapter = prev
+
+    @property
+    def active_adapter(self):
+        if self._context_active_adapter == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
+        return self._context_active_adapter
+
+    @active_adapter.setter
+    def active_adapter(self, value: Optional[str]):
+        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
+
+
+using_adapter = AdapterContextMixin.using_adapter
+
+
+class LoraLinear(lora.Linear, AdapterContextMixin):
+    """LoRA linear layer that uses adapter selected via using_adapter"""
+
+
+class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
+    """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
+
+
+class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
+    """LoRA linear 4-bit that uses adapter selected via using_adapter"""
+
+
 def create_lora_adapter(block, quant_type: QuantType):
     for _, module in block.named_modules():
         for child_name, child in module.named_children():
@@ -130,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "threshold": 6.0,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear8bitLt(
-                    child_name,
+                lora_wrapped_child = LoraLinear8bitLt(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
@@ -143,22 +185,21 @@
                     "blocksize": 64,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear4bit(
-                    child_name,
+                lora_wrapped_child = LoraLinear4bit(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
                 )
             else:
                 bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = lora.Linear(
-                    child_name,
+                lora_wrapped_child = LoraLinear(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     bias=bias,
                 )
             if lora_wrapped_child:
-                lora_wrapped_child.active_adapter = None
                 lora_wrapped_child.weight = child.weight
                 lora_wrapped_child.bias = child.bias
                 for p in lora_wrapped_child.parameters():
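One subtlety in the peft.py diff is worth spelling out: create_lora_adapter now passes AdapterContextMixin.ADAPTER_NOT_SET where the lora.* constructors expect an adapter name, and the mixin's active_adapter property (a data descriptor) intercepts both reads and writes of that attribute, including any assignment the peft base class makes during construction; the setter's assert enforces that only the sentinel is ever written directly. The sketch below reproduces just that mechanism with invented stand-in classes (ThirdPartyLinear, ContextMixin, WrappedLinear) and is not the Petals code.

# Sketch: a property on a mixin shadows instance writes made by a third-party base class.
ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"


class ThirdPartyLinear:
    """Stand-in for a library layer that assigns active_adapter in its constructor."""

    def __init__(self, adapter_name: str):
        self.active_adapter = adapter_name  # routed through the mixin's property setter below


class ContextMixin:
    _context_active_adapter = ADAPTER_NOT_SET

    @property
    def active_adapter(self):
        return ContextMixin._context_active_adapter

    @active_adapter.setter
    def active_adapter(self, value):
        assert value == ADAPTER_NOT_SET, "active adapter can only be changed via the context manager"


class WrappedLinear(ThirdPartyLinear, ContextMixin):
    """Mirrors LoraLinear(lora.Linear, AdapterContextMixin) from the diff."""


layer = WrappedLinear(ADAPTER_NOT_SET)         # the constructor write passes the sentinel, so it is allowed
ContextMixin._context_active_adapter = "demo"  # normally done by using_adapter(); set directly to keep this short
assert layer.active_adapter == "demo"          # reads come from the shared context, not from the instance

This is also why the line lora_wrapped_child.active_adapter = None was deleted: with the property in place, any direct assignment other than the sentinel would trip the setter's assert.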