From bc5637c9f4e2ba68d25943187f33c83730698e3b Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 14 Jul 2023 17:49:29 +0300
Subject: [PATCH 1/6] faster adapter switching

---
 src/petals/server/backend.py | 30 ++++++--------------
 src/petals/server/handler.py |  7 +++++
 src/petals/server/server.py  |  1 +
 src/petals/utils/peft.py     | 55 +++++++++++++++++++++++++++++++-----
 4 files changed, 64 insertions(+), 29 deletions(-)

diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 51c6ee075..91cc57deb 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -17,6 +17,7 @@
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
+from petals.utils.peft import using_global_adapter

 logger = get_logger(__name__)

@@ -82,13 +83,13 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S

     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().forward(*inputs)
+        with using_global_adapter(active_adapter):
+            return super().forward(*inputs)

     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().backward(*inputs)
+        with using_global_adapter(active_adapter):
+            return super().backward(*inputs)

     @torch.inference_mode()
@@ -98,8 +99,9 @@ def inference_step(
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        self.load_adapter_(inference_info.active_adapter)
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors, using_global_adapter(
+            inference_info.active_adapter
+        ):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
@@ -150,22 +152,6 @@ def shutdown(self):
         for p in self.module.parameters():
             p.data = dummy

-    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
-        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
-
-        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
-        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
-
-        loaded = False
-        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
-                layer.active_adapter = active_adapter  # empty string for no adapter
-                if active_adapter in layer.lora_A.keys():
-                    loaded = True
-
-        if active_adapter and not loaded:
-            raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded")
-

 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d9a50258c..6f4ed554b 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -68,6 +68,7 @@ def __init__(
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        adapters: Optional[Sequence[str]],
         dht_prefix: str,
         push_manager: multiprocessing.managers.SyncManager,
         session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
@@ -81,6 +82,7 @@ def __init__(
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
+        self.adapters = adapters
         self._push_manager = push_manager
         self._session_queues = session_queues
         self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
@@ -142,6 +144,7 @@ async def rpc_inference(
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
             active_adapter = metadata.get("active_adapter", "")
+            assert not active_adapter or active_adapter in self.adapters
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")

@@ -356,6 +359,7 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PCont
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         active_adapter = metadata.get("active_adapter", "")
+        assert not active_adapter or active_adapter in self.adapters
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
@@ -383,6 +387,7 @@ async def rpc_forward_stream(
         self._log_request("rpc_forward_stream", requested_uids, context)

         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         active_adapter = metadata.get("active_adapter", "")
+        assert not active_adapter or active_adapter in self.adapters
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
@@ -434,6 +439,7 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         active_adapter = metadata.get("active_adapter", "")
+        assert not active_adapter or active_adapter in self.adapters
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
@@ -459,6 +465,7 @@ async def rpc_backward_stream(
         self._log_request("rpc_backward_stream", requested_uids, context)

         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         active_adapter = metadata.get("active_adapter", "")
+        assert not active_adapter or active_adapter in self.adapters
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index e576d000a..2e556267f 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -530,6 +530,7 @@ def __init__(
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                adapters=adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 4f5164314..9d5cf9c9b 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -1,3 +1,4 @@
+import contextlib
 import re
 import time
 from typing import Optional, Sequence
@@ -118,6 +119,47 @@ def load_peft(
             time.sleep(delay)


+class GlobalAdapterMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey a globally set adapter"""
+
+    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
+    GLOBAL_ACTIVE_ADAPTER = None
+
+    @staticmethod
+    @contextlib.contextmanager
+    def using_global_adapter(active_adapter: Optional[str]):
+        prev, GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER = GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER, active_adapter
+        try:
+            yield
+        finally:
+            GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER = prev
+
+    @property
+    def active_adapter(self):
+        if self.GLOBAL_ACTIVE_ADAPTER == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_global_adapter. This should only be used for debug")
+        return self.GLOBAL_ACTIVE_ADAPTER
+
+    @active_adapter.setter
+    def active_adapter(self, value: Optional[str]):
+        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed globally, via .using_adapter" ""
+
+
+using_global_adapter = GlobalAdapterMixin.using_global_adapter
+
+
+class GlobalLoraLinear(lora.Linear, GlobalAdapterMixin):
+    """LoRA linear layer that uses globally selected active adapter"""
+
+
+class GlobalLoraLinear8bitLt(lora.Linear8bitLt, GlobalAdapterMixin):
+    """LoRA linear 8-bit with outliers that uses globally selected active adapter"""
+
+
+class GlobalLoraLinear4bit(lora.Linear4bit, GlobalAdapterMixin):
+    """LoRA linear 4-bit that uses globally selected active adapter"""
+
+
 def create_lora_adapter(block, quant_type: QuantType):
     for _, module in block.named_modules():
         for child_name, child in module.named_children():
@@ -130,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "threshold": 6.0,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear8bitLt(
-                    child_name,
+                lora_wrapped_child = GlobalLoraLinear8bitLt(
+                    GlobalAdapterMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
@@ -143,22 +185,21 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "blocksize": 64,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear4bit(
-                    child_name,
+                lora_wrapped_child = GlobalLoraLinear4bit(
+                    GlobalAdapterMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
                 )
             else:
                 bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = lora.Linear(
-                    child_name,
+                lora_wrapped_child = GlobalLoraLinear(
+                    GlobalAdapterMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     bias=bias,
                 )
             if lora_wrapped_child:
-                lora_wrapped_child.active_adapter = None
                 lora_wrapped_child.weight = child.weight
                 lora_wrapped_child.bias = child.bias
                 for p in lora_wrapped_child.parameters():

From ffe0c5001554904cdbc4d50dc88d1a6200be78d3 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 14 Jul 2023 17:53:38 +0300
Subject: [PATCH 2/6] default to special constant

---
 src/petals/utils/peft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 9d5cf9c9b..56212f3e8 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -123,7 +123,7 @@ class GlobalAdapterMixin:
     """A mixin that makes LoRA-wrapped linear layers obey a globally set adapter"""

     ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
-    GLOBAL_ACTIVE_ADAPTER = None
+    GLOBAL_ACTIVE_ADAPTER = ADAPTER_NOT_SET

     @staticmethod
     @contextlib.contextmanager

From 6c7c0688e4e8e974c0931d2e1e9fcf6b4cb2e804 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 19:40:13 +0300
Subject: [PATCH 3/6] review

---
 src/petals/server/backend.py |  8 +++----
 src/petals/utils/peft.py     | 46 ++++++++++++++++++------------------
 2 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 91cc57deb..a28f9da0f 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -17,7 +17,7 @@
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
-from petals.utils.peft import using_global_adapter
+from petals.utils.peft import using_adapter

 logger = get_logger(__name__)

@@ -83,12 +83,12 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S

     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_global_adapter(active_adapter):
+        with using_adapter(active_adapter):
             return super().forward(*inputs)

     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_global_adapter(active_adapter):
+        with using_adapter(active_adapter):
             return super().backward(*inputs)

     @torch.inference_mode()
@@ -99,7 +99,7 @@ def inference_step(
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors, using_global_adapter(
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors, using_adapter(
             inference_info.active_adapter
         ):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 56212f3e8..fb0721f8a 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -119,45 +119,45 @@ def load_peft(
             time.sleep(delay)


-class GlobalAdapterMixin:
-    """A mixin that makes LoRA-wrapped linear layers obey a globally set adapter"""
+class AdapterContextMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""

     ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
-    GLOBAL_ACTIVE_ADAPTER = ADAPTER_NOT_SET
+    _context_active_adapter = ADAPTER_NOT_SET

     @staticmethod
     @contextlib.contextmanager
-    def using_global_adapter(active_adapter: Optional[str]):
-        prev, GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER = GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER, active_adapter
+    def using_adapter(active_adapter: Optional[str]):
+        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
         try:
             yield
         finally:
-            GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER = prev
+            AdapterContextMixin._context_active_adapter = prev

     @property
     def active_adapter(self):
-        if self.GLOBAL_ACTIVE_ADAPTER == self.ADAPTER_NOT_SET:
-            logger.warning(f"Layer {self} was called without using_global_adapter. This should only be used for debug")
-        return self.GLOBAL_ACTIVE_ADAPTER
+        if self._context_active_adapter == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
+        return self._context_active_adapter

     @active_adapter.setter
     def active_adapter(self, value: Optional[str]):
-        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed globally, via .using_adapter" ""
+        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""


-using_global_adapter = GlobalAdapterMixin.using_global_adapter
+using_adapter = AdapterContextMixin.using_adapter


-class GlobalLoraLinear(lora.Linear, GlobalAdapterMixin):
-    """LoRA linear layer that uses globally selected active adapter"""
+class LoraLinear(lora.Linear, AdapterContextMixin):
+    """LoRA linear layer that uses adapter selected via using_adapter"""


-class GlobalLoraLinear8bitLt(lora.Linear8bitLt, GlobalAdapterMixin):
-    """LoRA linear 8-bit with outliers that uses globally selected active adapter"""
+class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
+    """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""


-class GlobalLoraLinear4bit(lora.Linear4bit, GlobalAdapterMixin):
-    """LoRA linear 4-bit that uses globally selected active adapter"""
+class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
+    """LoRA linear 4-bit that uses adapter selected via using_adapter"""


 def create_lora_adapter(block, quant_type: QuantType):
@@ -172,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "threshold": 6.0,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = GlobalLoraLinear8bitLt(
-                    GlobalAdapterMixin.ADAPTER_NOT_SET,
+                lora_wrapped_child = LoraLinear8bitLt(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
@@ -185,16 +185,16 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "blocksize": 64,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = GlobalLoraLinear4bit(
-                    GlobalAdapterMixin.ADAPTER_NOT_SET,
+                lora_wrapped_child = LoraLinear4bit(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
                 )
             else:
                 bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = GlobalLoraLinear(
-                    GlobalAdapterMixin.ADAPTER_NOT_SET,
+                lora_wrapped_child = LoraLinear(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     bias=bias,

From e092ec264e26112f95594c5c600c3c9da22f7974 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 20:04:37 +0300
Subject: [PATCH 4/6] DAED HELL SALAD

---
 src/petals/server/backend.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index a28f9da0f..42205467a 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -17,7 +17,6 @@
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
-from petals.utils.peft import using_adapter

 logger = get_logger(__name__)

@@ -25,9 +24,15 @@
 class TransformerBackend(ModuleBackend):
     """A wrapper for a transformer block that can process requests for forward, backward and inference"""

+    _peft_module = None
+
     def __init__(
         self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
     ):
+        import petals.utils.peft as _peft_module
+
+        self._peft_module = _peft_module
+
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -83,12 +88,12 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S

     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_adapter(active_adapter):
+        with self._peft_module.using_adapter(active_adapter):
             return super().forward(*inputs)

     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_adapter(active_adapter):
+        with self._peft_module.using_adapter(active_adapter):
             return super().backward(*inputs)

     @torch.inference_mode()
@@ -99,9 +104,9 @@ def inference_step(
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors, using_adapter(
-            inference_info.active_adapter
-        ):
+        with self.memory_cache.use_cache(
+            *inference_info.cache_handles
+        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)

From 494590cff280542ee59c609ca3d68f1e1bacc9be Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 20:11:33 +0300
Subject: [PATCH 5/6] review

---
 src/petals/server/handler.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 6f4ed554b..1bfef9306 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -143,8 +143,7 @@ async def rpc_inference(
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
-            active_adapter = metadata.get("active_adapter", "")
-            assert not active_adapter or active_adapter in self.adapters
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")

@@ -358,8 +357,7 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PCont
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        active_adapter = metadata.get("active_adapter", "")
-        assert not active_adapter or active_adapter in self.adapters
+        active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
@@ -386,8 +384,7 @@ async def rpc_forward_stream(
         self._log_request("rpc_forward_stream", requested_uids, context)

         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        active_adapter = metadata.get("active_adapter", "")
-        assert not active_adapter or active_adapter in self.adapters
+        active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
@@ -438,8 +435,7 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        active_adapter = metadata.get("active_adapter", "")
-        assert not active_adapter or active_adapter in self.adapters
+        active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
@@ -464,8 +460,7 @@ async def rpc_backward_stream(
         self._log_request("rpc_backward_stream", requested_uids, context)

         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        active_adapter = metadata.get("active_adapter", "")
-        assert not active_adapter or active_adapter in self.adapters
+        active_adapter = self._get_active_adapter(metadata)
         points = metadata.get("points", 0)
         assert isinstance(
             points, (float, int)
@@ -577,6 +572,12 @@ async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) ->

         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))

+    def _get_active_adapter(self, metadata: dict) -> str:
+        active_adapter = metadata.get("active_adapter", "")
+        if active_adapter and (active_adapter not in self.adapters):
+            raise KeyError(f"adapter {active_adapter} not found")
+        return active_adapter
+

 async def _rpc_forward(
     *flat_tensors: torch.Tensor,

From fa523fa1d67e99ec895d6aedaaf5d274c89d47e5 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 20:13:54 +0300
Subject: [PATCH 6/6] review

---
 src/petals/server/handler.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 1bfef9306..12fd6eb36 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -478,6 +478,12 @@ async def rpc_backward_stream(
             for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                 yield runtime_pb2.ExpertResponse(tensors=[part])

+    def _get_active_adapter(self, metadata: dict) -> str:
+        active_adapter = metadata.get("active_adapter", "")
+        if active_adapter and (active_adapter not in self.adapters):
+            raise KeyError(f"adapter {active_adapter} not found")
+        return active_adapter
+
     def _serialize_grads(
         self,
         grads: Sequence[torch.Tensor],
@@ -572,12 +578,6 @@ async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) ->

         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))

-    def _get_active_adapter(self, metadata: dict) -> str:
-        active_adapter = metadata.get("active_adapter", "")
-        if active_adapter and (active_adapter not in self.adapters):
-            raise KeyError(f"adapter {active_adapter} not found")
-        return active_adapter
-

 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
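
For reference, the adapter-switching mechanism that patches 1-4 converge on can be summarized with the following self-contained sketch. It mirrors the AdapterContextMixin / using_adapter pattern from src/petals/utils/peft.py, but it is only an illustration under stated assumptions: the DummyLoraLayer class and the adapter name below are made up, and the real mixin is combined with peft's lora.Linear, lora.Linear8bitLt, and lora.Linear4bit rather than a toy layer.

import contextlib
from typing import Optional


class AdapterContextMixin:
    """Makes a layer read its active adapter from a context set via using_adapter()."""

    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
    _context_active_adapter = ADAPTER_NOT_SET

    @staticmethod
    @contextlib.contextmanager
    def using_adapter(active_adapter: Optional[str]):
        # Swap the class-level adapter for the duration of the block, then restore it
        prev = AdapterContextMixin._context_active_adapter
        AdapterContextMixin._context_active_adapter = active_adapter
        try:
            yield
        finally:
            AdapterContextMixin._context_active_adapter = prev

    @property
    def active_adapter(self):
        return self._context_active_adapter


class DummyLoraLayer(AdapterContextMixin):
    """Illustrative stand-in for a LoRA-wrapped linear layer (not part of Petals)."""

    def forward(self, x):
        # A real layer would pick the LoRA weights named by self.active_adapter here
        return f"forward({x!r}) with adapter={self.active_adapter!r}"


layer = DummyLoraLayer()
with AdapterContextMixin.using_adapter("example-lora"):  # hypothetical adapter name
    print(layer.forward("hidden_states"))  # runs with "example-lora" selected
print(layer.forward("hidden_states"))  # outside the context: ADAPTER_NOT_SET

Because the active adapter is stored once on the mixin class and guarded by a context manager, switching adapters per request is a single assignment, instead of the walk over every LoRA submodule that the removed load_adapter_() method performed.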