Fix bugs in _choose_num_blocks() added in #346 (#354)

Merged · 3 commits · Jul 14, 2023
23 changes: 14 additions & 9 deletions src/petals/server/server.py
@@ -174,9 +174,13 @@ def __init__(
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
         # For attention cache in GPU or RAM
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
+        # For disk cache
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
+        self.adapters = adapters
 
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
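
For a sense of scale, here is a quick back-of-the-envelope check of the per-block attention cache estimate introduced above. The numbers are illustrative only (a roughly BLOOM-176B-sized block served in float16), not values taken from this PR:

import torch

# Illustrative values, not from the PR: a BLOOM-176B-sized block served in float16.
hidden_size = 14336         # stands in for block_config.hidden_size
attn_cache_tokens = 8192    # tokens of attention cache budgeted per block
torch_dtype = torch.float16

# Same formula as in __init__ above: one key and one value per hidden unit per token.
cache_values_per_block = 2 * hidden_size * attn_cache_tokens
cache_bytes_per_block = cache_values_per_block * torch.finfo(torch_dtype).bits // 8

print(f"{cache_bytes_per_block / 1024**3:.2f} GiB per block")  # ~0.44 GiB

With these numbers each served block reserves roughly 0.44 GiB for its attention cache on top of its weights, which is the quantity _choose_num_blocks() below now accounts for through self._cache_bytes_per_block.
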
@@ -197,9 +201,6 @@ def __init__(
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
         self.alloc_timeout = alloc_timeout
-        if cache_dir is None:
-            cache_dir = DEFAULT_CACHE_DIR
-        self.max_disk_space = max_disk_space
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
@@ -243,20 +244,24 @@ def _choose_num_blocks(self) -> int:
         else:
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
-
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
 
-        if adapters:

Review comment (PR author): adapters is not defined.

+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
+        total_memory_per_block = block_size + self._cache_bytes_per_block
+        if self.adapters:
             # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
             from petals.utils.peft import estimate_adapter_memory_per_block
 
-            adapter_memory_per_block = estimate_adapter_memory_per_block(
-                self.block_config, self.torch_dtype, self.adapters, self.cache_dir

Review comment (PR author): cache_dir is a kwarg-only argument. Also, we need to pass all the rest of the disk cache args.

+            total_memory_per_block += estimate_adapter_memory_per_block(
+                self.block_config,
+                self.torch_dtype,
+                self.adapters,
+                use_auth_token=self.use_auth_token,
+                cache_dir=self.cache_dir,
+                max_disk_space=self.max_disk_space,
+            )
-        total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block
 
         num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
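
Putting the fixed arithmetic together, here is a standalone sketch of how _choose_num_blocks() now budgets GPU memory. The structure mirrors the diff above; the concrete numbers are made up for illustration and are not measurements from the PR:

import math

gib = 1024**3

# Illustrative inputs, not from the PR.
total_memory = 24 * gib              # e.g. one 24 GiB GPU
num_devices = 1
hidden_size = 14336                  # BLOOM-176B-sized blocks
block_size = 3.5 * gib               # stand-in for get_block_size(...): weights of one block
cache_bytes_per_block = 0.44 * gib   # attention cache estimate computed in __init__
adapter_memory_per_block = 0         # grows if LoRA adapters are served

# Same structure as the fixed method: reserve autograd memory once,
# then divide what is left by the total footprint of a single block.
autograd_memory = 2 * gib * num_devices / 14336 * hidden_size
total_memory_per_block = block_size + cache_bytes_per_block + adapter_memory_per_block
num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
print(num_blocks)  # 5 with these numbers

The key fix is that the per-block cache and adapter costs are added into total_memory_per_block before the division, and that they are read from attributes (self.adapters, self.cache_dir, self.max_disk_space) that are now guaranteed to be set before _choose_num_blocks() runs.
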
7 changes: 5 additions & 2 deletions src/petals/utils/peft.py
@@ -217,7 +217,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
 
 
 def estimate_adapter_memory_per_block(
-    block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
+    block_config: transformers.PretrainedConfig,
+    torch_dtype: Optional[torch.dtype],
+    adapters: Sequence[str],
+    **load_peft_kwargs,
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
     with init_empty_weights(include_buffers=True):
@@ -226,7 +229,7 @@ def estimate_adapter_memory_per_block(
         create_lora_adapter(block, quant_type=QuantType.NONE)
 
         for adapter in adapters:
-            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
+            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
             assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
             add_adapter_to_block(
                 block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
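
To make the signature change concrete, here is a minimal, self-contained sketch using stub functions (not petals code) of why forwarding the disk-cache arguments as keywords works, while the old positional call did not:

def load_peft_stub(adapter, *, block_idx, use_auth_token=None, cache_dir=None, max_disk_space=None):
    # Stand-in for petals.utils.peft.load_peft, whose disk-cache arguments are keyword-only.
    return {"adapter": adapter, "cache_dir": cache_dir}

def estimate_stub(block_config, torch_dtype, adapters, **load_peft_kwargs):
    # Mirrors the new estimate_adapter_memory_per_block() signature: extra keyword
    # arguments are collected and forwarded untouched to the loader.
    for adapter in adapters:
        load_peft_stub(adapter, block_idx=0, **load_peft_kwargs)
    return 0

# New-style call: disk cache args go through as keywords, just like in server.py above.
estimate_stub({}, None, ["demo-adapter"], use_auth_token=None, cache_dir="/tmp/petals", max_disk_space=None)

# Old-style call from before this PR: cache_dir as a 4th positional argument.
try:
    estimate_stub({}, None, ["demo-adapter"], "/tmp/petals")
except TypeError as exc:
    print(exc)  # estimate_stub() takes 3 positional arguments but 4 were given

This is exactly the failure behind the "cache_dir is a kwarg-only" review comment: a **kwargs parameter only collects keyword arguments, so the old positional call would fail with a TypeError before any memory estimate was computed.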