Estimate adapter memory overhead in choose_num_blocks() (#346)
* estimate adapter memory overhead
* reduce the number of served blocks based on that estimate

---------

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic and borzunov committed Jul 13, 2023
1 parent: f605f09 · commit: 010857a
Showing 3 changed files with 45 additions and 9 deletions.
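
For context, the core of this change is a new per-block memory accounting: LoRA adapter weights are now added to each block's footprint before the server decides how many blocks it can host. The sketch below restates that accounting outside the Server class; the function name, the GIB constant, and all concrete byte values are illustrative placeholders rather than code from this commit.

    import math

    GIB = 1024**3  # bytes per GiB

    def sketch_choose_num_blocks(total_memory, autograd_memory, block_size,
                                 cache_bytes_per_block, adapter_memory_per_block):
        # Each hosted block now costs: block weights + attention cache + LoRA adapter overhead.
        total_memory_per_block = block_size + adapter_memory_per_block + cache_bytes_per_block
        return math.floor((total_memory - autograd_memory) / total_memory_per_block)

    # Illustrative numbers only: 80 GiB free, 2 GiB reserved for rpc_backward,
    # 3.5 GiB of block weights, 0.5 GiB of cache and 0.5 GiB of adapters per block.
    print(sketch_choose_num_blocks(80 * GIB, 2 * GIB, 3.5 * GIB, 0.5 * GIB, 0.5 * GIB))  # -> 17
    print(sketch_choose_num_blocks(80 * GIB, 2 * GIB, 3.5 * GIB, 0.5 * GIB, 0))          # -> 19 (old formula)

With these made-up numbers the server would now pick 17 blocks instead of the 19 that the pre-commit formula, which ignored adapters, would have reported.
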
17 changes: 13 additions & 4 deletions src/petals/server/server.py
@@ -30,6 +30,7 @@
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.peft import estimate_adapter_memory_per_block
from petals.utils.version import get_compatible_model_repo

logger = get_logger(__name__)
@@ -176,6 +177,8 @@ def __init__(

cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
self.cache_dir = cache_dir
self.adapters = adapters

assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
@@ -197,7 +200,6 @@ def __init__(
self.alloc_timeout = alloc_timeout
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
self.cache_dir = cache_dir
self.max_disk_space = max_disk_space

assert isinstance(throughput, float) or throughput in ["auto", "eval"]
@@ -219,8 +221,6 @@ def __init__(
self.mean_balance_check_period = mean_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay

self.adapters = adapters

self.stop = threading.Event()

def _choose_num_blocks(self) -> int:
@@ -250,7 +250,16 @@ def _choose_num_blocks(self) -> int:
# Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size

num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
if self.adapters:
# Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
from petals.utils.peft import estimate_adapter_memory_per_block

adapter_memory_per_block = estimate_adapter_memory_per_block(
self.block_config, self.torch_dtype, self.adapters, cache_dir=self.cache_dir
)
total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block

num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
2 changes: 0 additions & 2 deletions src/petals/utils/convert_block.py
@@ -55,8 +55,6 @@ def convert_block(
shard.to(device)

if adapters:
# Import petals.utils.peft only when necessary to avoid importing bitsandbytes
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft

create_lora_adapter(block, quant_type=quant_type)
35 changes: 32 additions & 3 deletions src/petals/utils/peft.py
@@ -1,9 +1,16 @@
import os
import re
import time
from typing import List, Optional
from typing import List, Optional, Sequence

os.environ["BITSANDBYTES_NOWELCOME"] = "1"

import bitsandbytes as bnb
import peft
import torch
import torch.nn as nn
import transformers
from accelerate import init_empty_weights
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.tuners import lora
@@ -12,6 +19,8 @@
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.client.ptune import force_non_empty_weights
from petals.server.block_utils import resolve_block_dtype
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import QuantType

@@ -194,15 +203,35 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state
p.requires_grad = False

if peft_key.endswith(".lora_A.weight"):
child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
is_lora_a_loaded = True
elif peft_key.endswith(".lora_A.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
elif peft_key.endswith(".lora_B.weight"):
child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
is_lora_b_loaded = True
elif peft_key.endswith(".lora_B.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")

if is_lora_a_loaded and is_lora_b_loaded:
logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")


def estimate_adapter_memory_per_block(
block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
) -> int:
"""Get the number of extra bytes used to store a set of adapters per given block"""
with init_empty_weights(include_buffers=True):
block = block_config.block_class(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)

for adapter in adapters:
peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
add_adapter_to_block(
block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
)
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
return adapter_parameters * bytes_per_parameter
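
As a quick illustration (not part of the commit), the new helper can be called on its own to see how much extra memory a set of LoRA adapters adds per transformer block. The model and adapter repository names below are placeholders, and forwarding cache_dir to load_peft() through **kwargs is an assumption about the surrounding Petals APIs.

    import torch
    from petals.utils.auto_config import AutoDistributedConfig
    from petals.utils.peft import estimate_adapter_memory_per_block

    # Hypothetical example: any Petals-compatible config exposing .block_class should work here.
    config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")  # placeholder model
    overhead = estimate_adapter_memory_per_block(
        config,
        torch.float16,
        adapters=["artek0chumak/bloom-560m-safe-peft"],  # placeholder LoRA adapter repo
        cache_dir=None,  # assumed to be forwarded to load_peft() via **kwargs
    )
    print(f"Adapter overhead per block: {overhead / 1024**2:.1f} MiB")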
