From 80e6cb3bb41c58bd1410176be715f4c364b07745 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 15:33:17 +0000
Subject: [PATCH 1/3] Fix _choose_num_blocks()

---
 src/petals/server/server.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index be580d780..686c8c2e6 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -174,9 +174,13 @@ def __init__(
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")

+        # For attention cache in GPU or RAM
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+
+        # For disk cache
         self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
         self.adapters = adapters

         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
@@ -197,9 +201,6 @@
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

         self.alloc_timeout = alloc_timeout
-        if cache_dir is None:
-            cache_dir = DEFAULT_CACHE_DIR
-        self.max_disk_space = max_disk_space

         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
@@ -249,12 +250,17 @@ def _choose_num_blocks(self) -> int:
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size

-        if adapters:
+        if self.adapters:
             # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
             from petals.utils.peft import estimate_adapter_memory_per_block

             adapter_memory_per_block = estimate_adapter_memory_per_block(
-                self.block_config, self.torch_dtype, self.adapters, self.cache_dir
+                self.block_config,
+                self.torch_dtype,
+                self.adapters,
+                use_auth_token=self.use_auth_token,
+                cache_dir=self.cache_dir,
+                max_disk_space=self.max_disk_space,
             )
             total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block


From c73ab431ba57fc3d83266514adf2088973d18860 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 15:42:49 +0000
Subject: [PATCH 2/3] Rename **kwargs -> **load_peft_kwargs in
 estimate_adapter_memory_per_block()

---
 src/petals/utils/peft.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 4f5164314..c537a32bc 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -217,7 +217,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta


 def estimate_adapter_memory_per_block(
-    block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
+    block_config: transformers.PretrainedConfig,
+    torch_dtype: Optional[torch.dtype],
+    adapters: Sequence[str],
+    **load_peft_kwargs,
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
     with init_empty_weights(include_buffers=True):
@@ -226,7 +229,7 @@
         create_lora_adapter(block, quant_type=QuantType.NONE)

     for adapter in adapters:
-        peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
+        peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
         assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
         add_adapter_to_block(
             block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict

From 9c21f044e00423126cdf7f22e84c00831d5df606 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 14 Jul 2023 18:23:51 +0000
Subject: [PATCH 3/3] Fix _choose_num_blocks() once again

---
 src/petals/server/server.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 686c8c2e6..c90ae4498 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -244,17 +244,17 @@ def _choose_num_blocks(self) -> int:
         else:
             total_memory = torch.cuda.get_device_properties(self.device).total_memory

-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
-
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
         autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size

+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
+        total_memory_per_block = block_size + self._cache_bytes_per_block
         if self.adapters:
             # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
             from petals.utils.peft import estimate_adapter_memory_per_block

-            adapter_memory_per_block = estimate_adapter_memory_per_block(
+            total_memory_per_block += estimate_adapter_memory_per_block(
                 self.block_config,
                 self.torch_dtype,
                 self.adapters,
@@ -261,8 +261,7 @@ def _choose_num_blocks(self) -> int:
                 use_auth_token=self.use_auth_token,
                 cache_dir=self.cache_dir,
                 max_disk_space=self.max_disk_space,
             )
-        total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block

         num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
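
For context, a minimal usage sketch of the signature introduced above. It is
not part of the patch: the adapter id, cache path, and disk budget are
hypothetical placeholders, and block_config is assumed to be a per-block
transformers config obtained elsewhere.

    import torch
    from petals.utils.peft import estimate_adapter_memory_per_block

    extra_bytes = estimate_adapter_memory_per_block(
        block_config,                             # assumed: config for one transformer block
        torch.float16,
        adapters=["someuser/some-lora-adapter"],  # hypothetical LoRA repo id
        # Everything below is collected into **load_peft_kwargs
        # and forwarded unchanged to load_peft():
        use_auth_token=None,
        cache_dir="/tmp/petals_cache",            # hypothetical path
        max_disk_space=10 * 1024**3,              # 10 GiB
    )

Keyword-only forwarding is what the rename makes explicit: positional
arguments such as the old `self.cache_dir` would land in **load_peft_kwargs
only by accident, which is exactly the bug fixed in PATCH 1/3.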