Switch adapters slightly faster #353

Merged
merged 9 commits into from
Jul 14, 2023

Changes from 1 commit
review
borzunov authored and Your Name committed Jul 14, 2023
commit 6c7c0688e4e8e974c0931d2e1e9fcf6b4cb2e804
8 changes: 4 additions & 4 deletions src/petals/server/backend.py
@@ -17,7 +17,7 @@
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
-from petals.utils.peft import using_global_adapter
+from petals.utils.peft import using_adapter
 
 logger = get_logger(__name__)

@@ -83,12 +83,12 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_global_adapter(active_adapter):
+        with using_adapter(active_adapter):
             return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with using_global_adapter(active_adapter):
+        with using_adapter(active_adapter):
             return super().backward(*inputs)
 
     @torch.inference_mode()
@@ -99,7 +99,7 @@ def inference_step(
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors, using_global_adapter(
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors, using_adapter(
             inference_info.active_adapter
         ):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
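For illustration, here is a minimal, self-contained sketch of the calling convention this change relies on (the class, tensor shape, and adapter name below are hypothetical stand-ins, not Petals code): the active adapter travels with every call as the trailing input and is applied only for the duration of that call.

import contextlib
from typing import Tuple, Union

import torch


class ToyBackend:
    """Stand-in for TransformerBackend: unpacks the trailing adapter name on each call."""

    @contextlib.contextmanager
    def _using_adapter(self, name: str):
        # A real backend would activate the named LoRA adapter here.
        print(f"adapter {name!r} active")
        try:
            yield
        finally:
            print(f"adapter {name!r} released")

    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        *inputs, active_adapter = inputs  # adapter name is always the last input
        with self._using_adapter(active_adapter):
            return tuple(x * 2 for x in inputs)  # placeholder for the real block forward


backend = ToyBackend()
(output,) = backend.forward(torch.ones(1, 4, 8), "my-adapter")
print(output.shape)  # torch.Size([1, 4, 8])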
46 changes: 23 additions & 23 deletions src/petals/utils/peft.py
@@ -119,45 +119,45 @@ def load_peft(
             time.sleep(delay)
 
 
-class GlobalAdapterMixin:
-    """A mixin that makes LoRA-wrapped linear layers obey a globally set adapter"""
+class AdapterContextMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
 
     ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
-    GLOBAL_ACTIVE_ADAPTER = ADAPTER_NOT_SET
+    _context_active_adapter = ADAPTER_NOT_SET
 
     @staticmethod
     @contextlib.contextmanager
-    def using_global_adapter(active_adapter: Optional[str]):
-        prev, GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER = GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER, active_adapter
+    def using_adapter(active_adapter: Optional[str]):
+        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
         try:
            yield
         finally:
-            GlobalAdapterMixin.GLOBAL_ACTIVE_ADAPTER = prev
+            AdapterContextMixin._context_active_adapter = prev
 
     @property
     def active_adapter(self):
-        if self.GLOBAL_ACTIVE_ADAPTER == self.ADAPTER_NOT_SET:
-            logger.warning(f"Layer {self} was called without using_global_adapter. This should only be used for debug")
-        return self.GLOBAL_ACTIVE_ADAPTER
+        if self._context_active_adapter == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
+        return self._context_active_adapter
 
     @active_adapter.setter
     def active_adapter(self, value: Optional[str]):
-        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed globally, via .using_adapter" ""
+        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
 
 
-using_global_adapter = GlobalAdapterMixin.using_global_adapter
+using_adapter = AdapterContextMixin.using_adapter
 
 
-class GlobalLoraLinear(lora.Linear, GlobalAdapterMixin):
-    """LoRA linear layer that uses globally selected active adapter"""
+class LoraLinear(lora.Linear, AdapterContextMixin):
+    """LoRA linear layer that uses adapter selected via using_adapter"""
 
 
-class GlobalLoraLinear8bitLt(lora.Linear8bitLt, GlobalAdapterMixin):
-    """LoRA linear 8-bit with outliers that uses globally selected active adapter"""
+class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
+    """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
 
 
-class GlobalLoraLinear4bit(lora.Linear4bit, GlobalAdapterMixin):
-    """LoRA linear 4-bit that uses globally selected active adapter"""
+class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
+    """LoRA linear 4-bit that uses adapter selected via using_adapter"""
 
 
 def create_lora_adapter(block, quant_type: QuantType):
@@ -172,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "threshold": 6.0,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = GlobalLoraLinear8bitLt(
-                    GlobalAdapterMixin.ADAPTER_NOT_SET,
+                lora_wrapped_child = LoraLinear8bitLt(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
@@ -185,16 +185,16 @@
                     "blocksize": 64,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = GlobalLoraLinear4bit(
-                    GlobalAdapterMixin.ADAPTER_NOT_SET,
+                lora_wrapped_child = LoraLinear4bit(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
                 )
             else:
                 bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = GlobalLoraLinear(
-                    GlobalAdapterMixin.ADAPTER_NOT_SET,
+                lora_wrapped_child = LoraLinear(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     bias=bias,
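And a minimal, self-contained sketch of the context-scoped adapter mechanism itself (a simplified toy, not the Petals classes, which additionally inherit from peft's lora.Linear variants): entering the with block changes which adapter every wrapped layer sees, and the previous value is restored on exit.

import contextlib
from typing import Optional


class AdapterContext:
    """Toy stand-in for AdapterContextMixin: one shared, context-scoped adapter slot."""

    NOT_SET = "__ADAPTER_NOT_SET"
    _active = NOT_SET

    @classmethod
    @contextlib.contextmanager
    def using_adapter(cls, name: Optional[str]):
        prev, cls._active = cls._active, name
        try:
            yield
        finally:
            cls._active = prev  # previous adapter is restored even if the block raises


class ToyLoraLayer(AdapterContext):
    def forward(self, x: float) -> float:
        # A real LoRA layer would look up self._active among its loaded adapters;
        # here the adapter just selects a different scale factor.
        scale = {"alice": 2.0, "bob": 3.0}.get(self._active, 1.0)
        return x * scale


layer = ToyLoraLayer()
print(layer.forward(1.0))  # 1.0 -- no adapter set
with ToyLoraLayer.using_adapter("alice"):
    print(layer.forward(1.0))  # 2.0 -- "alice" active only inside the block
print(layer.forward(1.0))  # 1.0 -- restored after the block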