From 86d1afb7b8e1f069ec05882a85b44dd8296066fb Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 13 Jul 2023 02:42:58 +0300
Subject: [PATCH 1/5] estimate adapter memory overhead

---
 src/petals/server/server.py | 13 +++++++++----
 src/petals/utils/peft.py    | 32 +++++++++++++++++++++++++++++---
 2 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 643bf1bdc..1268c830f 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -30,6 +30,7 @@
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.peft import estimate_adapter_memory_per_block
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -176,6 +177,8 @@ def __init__(
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self.cache_dir = cache_dir
+        self.adapters = adapters
 
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
@@ -197,7 +200,6 @@ def __init__(
         self.alloc_timeout = alloc_timeout
         if cache_dir is None:
             cache_dir = DEFAULT_CACHE_DIR
-        self.cache_dir = cache_dir
         self.max_disk_space = max_disk_space
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
@@ -219,8 +221,6 @@ def __init__(
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
-        self.adapters = adapters
-
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
@@ -250,7 +250,12 @@ def _choose_num_blocks(self) -> int:
 
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
-        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
+        adapter_memory_per_block = estimate_adapter_memory_per_block(
+            self.block_config, self.torch_dtype, self.adapters, self.cache_dir
+        )
+        total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block
+
+        num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index c551f9757..5cef4e789 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -1,9 +1,13 @@
 import re
 import time
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 import bitsandbytes as bnb
+import peft
+import torch
 import torch.nn as nn
+import transformers
+from accelerate import init_empty_weights
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.tuners import lora
@@ -12,6 +16,8 @@
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
+from petals.client.ptune import force_non_empty_weights
+from petals.server.block_utils import resolve_block_dtype
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import QuantType
 
@@ -194,15 +200,35 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                             p.requires_grad = False
 
                     if peft_key.endswith(".lora_A.weight"):
-                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
+                        child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
                         is_lora_a_loaded = True
                     elif peft_key.endswith(".lora_A.bias"):
                         raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                     elif peft_key.endswith(".lora_B.weight"):
-                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
                         is_lora_b_loaded = True
                     elif peft_key.endswith(".lora_B.bias"):
                         raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
 
                 if is_lora_a_loaded and is_lora_b_loaded:
                     logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
+
+
+def estimate_adapter_memory_per_block(
+    block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
+) -> int:
+    """Get the number of extra bytes used to store a set of adapters per given block"""
+    with init_empty_weights(include_buffers=True):
+        block = block_config.block_class(block_config)
+        base_block_parameters = sum(p.numel() for p in block.parameters())
+        create_lora_adapter(block, quant_type=QuantType.NONE)
+
+        for adapter in adapters:
+            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
+            assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
+            add_adapter_to_block(
+                block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
+            )
+        adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
+    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    return adapter_parameters * bytes_per_parameter

From d18e2d14fba6a50f66932ed07a0f850469a0a659 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Fri, 14 Jul 2023 00:18:32 +0300
Subject: [PATCH 2/5] Update src/petals/server/server.py

Co-authored-by: Alexander Borzunov
---
 src/petals/server/server.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 1268c830f..a9cdbbe57 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -250,9 +250,13 @@ def _choose_num_blocks(self) -> int:
 
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
-        adapter_memory_per_block = estimate_adapter_memory_per_block(
-            self.block_config, self.torch_dtype, self.adapters, self.cache_dir
-        )
+        if adapters:
+            # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
+            from petals.utils.peft import estimate_adapter_memory_per_block
+            
+            adapter_memory_per_block = estimate_adapter_memory_per_block(
+                self.block_config, self.torch_dtype, self.adapters, self.cache_dir
+            )
         total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block
 
         num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)

From 58a92cbb64c57262d66a9823b0d9979f4d87c237 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Fri, 14 Jul 2023 00:19:02 +0300
Subject: [PATCH 3/5] Update src/petals/utils/peft.py

Co-authored-by: Alexander Borzunov
---
 src/petals/utils/peft.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 5cef4e789..0b50472ee 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -1,7 +1,10 @@
 import re
 import time
+import os
 from typing import List, Optional, Sequence
 
+os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+
 import bitsandbytes as bnb
 import peft
 import torch

From 832cb51cf033c9a91c5fb237e09c24fed5b6e489 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Fri, 14 Jul 2023 00:22:52 +0300
Subject: [PATCH 4/5] black

---
 src/petals/server/server.py | 2 +-
 src/petals/utils/peft.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index a9cdbbe57..e576d000a 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -253,7 +253,7 @@ def _choose_num_blocks(self) -> int:
         if adapters:
             # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
             from petals.utils.peft import estimate_adapter_memory_per_block
-            
+
             adapter_memory_per_block = estimate_adapter_memory_per_block(
                 self.block_config, self.torch_dtype, self.adapters, self.cache_dir
             )
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 0b50472ee..7eb72ef62 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -1,6 +1,6 @@
+import os
 import re
 import time
-import os
 from typing import List, Optional, Sequence
 
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"

From 49b6cc86a35ad65e73f79944437a20682e2a6741 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 00:24:40 +0300
Subject: [PATCH 5/5] review

---
 src/petals/utils/convert_block.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index 5c04092a8..b75709d86 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -55,8 +55,6 @@ def convert_block(
         shard.to(device)
 
     if adapters:
-        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
-        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
         from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
         create_lora_adapter(block, quant_type=quant_type)
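
A minimal usage sketch (not part of the patches above) of how the new estimate_adapter_memory_per_block helper could be exercised on its own once this series is applied. It assumes a Petals checkout with these patches, that load_peft accepts a cache_dir keyword (forwarded here via **kwargs), and a placeholder LoRA adapter repository that you would replace with one you actually serve.

    import torch

    from petals.utils.auto_config import AutoDistributedConfig
    from petals.utils.disk_cache import DEFAULT_CACHE_DIR
    from petals.utils.peft import estimate_adapter_memory_per_block

    # Any model config supported by Petals works here; bloom-560m is just a small example.
    block_config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")

    adapter_bytes = estimate_adapter_memory_per_block(
        block_config,
        torch.float16,
        ["your-org/your-lora-adapter"],  # hypothetical adapter repo id, substitute your own
        cache_dir=DEFAULT_CACHE_DIR,  # assumed to be passed through **kwargs to load_peft
    )
    print(f"Extra memory per block for these adapters: {adapter_bytes / 1024**2:.1f} MiB")

This mirrors the server-side call added in _choose_num_blocks: the estimate is added to block_size and the attention-cache bytes, so each extra adapter directly reduces the number of blocks the server decides it can host.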