From c8b85555f434c2f5919798025c668e2e23d621e8 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Fri, 22 Jul 2022 12:07:05 +0000 Subject: [PATCH 01/12] Extract ModuleContainer class from Server --- cli/run_server.py | 4 +- src/server/server.py | 284 +++++++++++++++++++++++++++++-------------- 2 files changed, 193 insertions(+), 95 deletions(-) diff --git a/cli/run_server.py b/cli/run_server.py index 4e9e7864c..04738db08 100644 --- a/cli/run_server.py +++ b/cli/run_server.py @@ -43,7 +43,7 @@ def main(): help='Use this many threads to pass results/exceptions from Runtime to Pools') parser.add_argument('--inference_max_length', type=int, default=16384, help='Maximum total sequence length permitted per inference, defaults to 16384 tokens') - parser.add_argument('--cache_dir', type=str, default=None, + parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') parser.add_argument('--device', type=str, default=None, required=False, help='all blocks will use this device in torch notation; default: cuda if available else cpu') @@ -104,7 +104,7 @@ def main(): use_auth_token = args.pop("use_auth_token") args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token - server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size) + server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True) try: server.join() diff --git a/src/server/server.py b/src/server/server.py index 1e4fd8da0..3b20ef205 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -4,7 +4,7 @@ import random import threading import time -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, List, Sequence, Union import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time @@ -29,76 +29,14 @@ class Server(threading.Thread): - """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT""" + """ + Runs Server, periodically checks that the network is balanced, + restarts the Server with other layers if the imbalance is significant + """ def __init__( self, - dht: DHT, - module_backends: Dict[str, TransformerBackend], - *, - inference_max_length: int, - num_connection_handlers: int = 8, - throughput: float, - update_period: float = 30, - expiration: Optional[float] = None, - start: bool, - **kwargs, - ): - threading.Thread.__init__(self) - self.dht, self.module_backends = dht, module_backends - self.throughput, self.update_period, self.expiration = throughput, update_period, expiration - self.conn_handlers = [ - TransformerConnectionHandler(dht, self.module_backends, inference_max_length) - for _ in range(num_connection_handlers) - ] - self.runtime = Runtime(self.module_backends, **kwargs) - self.dht_handler_thread = ModuleAnnouncerThread( - self.module_backends, - dht, - throughput=throughput, - update_period=update_period, - expiration=expiration, - daemon=True, - ) - self.checkpoint_saver = None # no need to save checkpoints since we do not change model state - - if start: - self.run_in_background(await_ready=True) - - def run(self): - """ - Starts Server in the current thread. Initializes dht if necessary, starts connection handlers, - runs Runtime (self.runtime) to process incoming requests. - """ - logger.info(f"Serving {len(self.module_backends)} blocks:") - for block_name, backend in self.module_backends.items(): - num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad) - parameter_msg = f"{num_parameters} trainable parameters" if num_parameters else "frozen" - logger.info(f"{block_name}: {backend.module.__class__.__name__}, {parameter_msg}") - - if not self.dht.is_alive(): - self.dht.run_in_background(await_ready=True) - if self.module_backends: - self.dht_handler_thread.start() - - if self.checkpoint_saver is not None: - self.checkpoint_saver.start() - - for process in self.conn_handlers: - if not process.is_alive(): - process.start() - process.ready.result() - - try: - self.runtime.run() - finally: - self.shutdown() - - # noinspection PyMethodOverriding - @classmethod - def create( - cls, prefix: Optional[str], converted_model_name_or_path: str, throughput: Union[float, str], @@ -127,10 +65,26 @@ def create( *, start: bool, **kwargs, - ) -> Server: + ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" + + super().__init__() + + self.converted_model_name_or_path = converted_model_name_or_path + self.num_handlers = num_handlers + self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size + self.inference_max_length = inference_max_length + self.cache_dir = cache_dir + self.attn_cache_size = attn_cache_size + self.compression = compression + self.stats_report_interval, self.update_period = stats_report_interval, update_period + self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads + self.use_auth_token = use_auth_token + self.load_in_8bit = load_in_8bit + if custom_module_path is not None: add_custom_models_from_file(custom_module_path) + if prefix is None: prefix = converted_model_name_or_path assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, ( @@ -138,27 +92,37 @@ def create( f"Please specify --prefix manually when starting a server" ) logger.info(f"Automatic dht prefix: {prefix}") + self.prefix = prefix + assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both" + if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) + self.expiration = expiration - dht = DHT(initial_peers=initial_peers, start=True, **kwargs) - visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()] + self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs) + visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") device = device or ("cuda" if torch.cuda.is_available() else "cpu") - memory_cache = MemoryCache(device, attn_cache_size) + self.device = device + + self.memory_cache = MemoryCache(device, attn_cache_size) assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: throughput = get_host_throughput(device, force_eval=(throughput == "eval")) + self.throughput = throughput if isinstance(torch_dtype, str): torch_dtype = DTYPE_MAP[torch_dtype] assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" + self.torch_dtype = torch_dtype - block_config = BloomConfig.from_pretrained( - converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision + self.block_config = BloomConfig.from_pretrained( + converted_model_name_or_path, + use_auth_token=use_auth_token, + revision=revision, ) if block_indices is not None: @@ -175,10 +139,148 @@ def create( time.sleep(random.random() * max_block_selection_delay) assert num_blocks is not None - uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)] - module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf")) + uids = [f"{prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] + module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf")) block_indices = choose_best_blocks(num_blocks, module_infos) + self.block_indices = block_indices + + self.stop = threading.Event() + if start: + self.start() + + def run(self): + self.module_container = ModuleContainer.create( + dht=self.dht, + prefix=self.prefix, + converted_model_name_or_path=self.converted_model_name_or_path, + block_config=self.block_config, + memory_cache=self.memory_cache, + throughput=self.throughput, + block_indices=self.block_indices, + num_handlers=self.num_handlers, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + inference_max_length=self.inference_max_length, + torch_dtype=self.torch_dtype, + cache_dir=self.cache_dir, + device=self.device, + compression=self.compression, + stats_report_interval=self.stats_report_interval, + update_period=self.update_period, + expiration=self.expiration, + prefetch_batches=self.prefetch_batches, + sender_threads=self.sender_threads, + use_auth_token=self.use_auth_token, + load_in_8bit=self.load_in_8bit, + start=True, + ) + try: + self.stop.wait() + finally: + self.module_container.shutdown() + + def shutdown(self): + self.stop.set() + + self.dht.shutdown() + self.dht.join() + +class ModuleContainer(threading.Thread): + """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT.""" + + def __init__( + self, + dht: DHT, + module_backends: Dict[str, TransformerBackend], + *, + device: torch.device, + num_connection_handlers: int, + throughput: float, + update_period: float, + expiration: Optional[float] = None, + start: bool, + **kwargs, + ): + super().__init__() + + self.dht, self.module_backends = dht, module_backends + self.throughput, self.update_period, self.expiration = throughput, update_period, expiration + self.conn_handlers = [ + TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers) + ] + self.runtime = Runtime(self.module_backends, device=device, **kwargs) + self.dht_handler_thread = ModuleAnnouncerThread( + self.module_backends, + dht, + throughput=throughput, + update_period=update_period, + expiration=expiration, + daemon=True, + ) + self.checkpoint_saver = None # no need to save checkpoints since we do not change model state + + if start: + self.run_in_background(await_ready=True) + + def run(self): + """ + Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers, + runs Runtime (self.runtime) to process incoming requests. + """ + logger.info(f"Serving {len(self.module_backends)} blocks:") + for expert_name, backend in self.module_backends.items(): + num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad) + logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters") + + if not self.dht.is_alive(): + self.dht.run_in_background(await_ready=True) + + if self.module_backends: + self.dht_handler_thread.start() + + if self.checkpoint_saver is not None: + self.checkpoint_saver.start() + + for process in self.conn_handlers: + if not process.is_alive(): + process.start() + process.ready.result() + + try: + self.runtime.run() + finally: + self.shutdown() + + # noinspection PyMethodOverriding + @classmethod + def create( + cls, + *, + dht: DHT, + prefix: str, + converted_model_name_or_path: str, + block_config: BloomConfig, + memory_cache: MemoryCache, + throughput: float, + block_indices: List[int], + num_handlers: Optional[int], + min_batch_size: int, + max_batch_size: int, + inference_max_length: int, + torch_dtype: torch.dtype, + cache_dir: Optional[str], + device: Union[str, torch.device], + compression: CompressionType, + stats_report_interval: Optional[int], + update_period: float, + expiration: Optional[float], + prefetch_batches: int, + sender_threads: int, + use_auth_token: Optional[str], + load_in_8bit: bool, + start: bool, + ) -> ModuleContainer: module_uids = [f"{prefix}.{block_index}" for block_index in block_indices] declare_active_modules( dht, @@ -245,33 +347,36 @@ def create( def run_in_background(self, await_ready=True, timeout=None): """ - Starts Server in a background thread. if await_ready, this method will wait until background server + Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container is ready to process incoming requests or for :timeout: seconds max. """ self.start() if await_ready and not self.ready.wait(timeout=timeout): - raise TimeoutError("Server didn't notify .ready in {timeout} seconds") + raise TimeoutError("ModuleContainer didn't notify .ready in {timeout} seconds") @property def ready(self) -> mp.synchronize.Event: """ - An event (multiprocessing.Event) that is set when the server is ready to process requests. + An event (multiprocessing.Event) that is set when the container is ready to process requests. Example ======= - >>> server.start() - >>> server.ready.wait(timeout=10) - >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds") + >>> container.start() + >>> container.ready.wait(timeout=10) + >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds") """ return self.runtime.ready # mp.Event that is true if self is ready to process batches def shutdown(self): """ - Gracefully terminate the server, process-safe. - Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes. + Gracefully terminate the container, process-safe. + Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes. If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL). """ if self.module_backends: + self.dht_handler_thread.stop.set() + self.dht_handler_thread.join() + declare_active_modules( self.dht, self.module_backends.keys(), @@ -288,25 +393,18 @@ def shutdown(self): process.join() logger.debug("Connection handlers terminated") - if self.module_backends: - self.dht_handler_thread.stop.set() - self.dht_handler_thread.join() - if self.checkpoint_saver is not None: self.checkpoint_saver.stop.set() self.checkpoint_saver.join() - self.dht.shutdown() - self.dht.join() - logger.debug(f"Shutting down runtime") - self.runtime.shutdown() - logger.info("Server shut down succesfully") + + logger.info("Module container shut down succesfully") class ModuleAnnouncerThread(threading.Thread): - """Periodically announces that this server hosts the specified modules, visible to all DHT peers""" + """Periodically announces that this container hosts the specified modules, visible to all DHT peers""" def __init__( self, From 325ff0cef9758a33150701f08acb0c25bc26fd91 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Fri, 22 Jul 2022 12:44:39 +0000 Subject: [PATCH 02/12] Draft balance check logic --- src/server/server.py | 99 ++++++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 41 deletions(-) diff --git a/src/server/server.py b/src/server/server.py index 3b20ef205..c993fa66c 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -60,6 +60,7 @@ def __init__( prefetch_batches: int = 1, sender_threads: int = 1, max_block_selection_delay: float = 1, + max_balance_check_period: float = 600, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, @@ -94,8 +95,6 @@ def __init__( logger.info(f"Automatic dht prefix: {prefix}") self.prefix = prefix - assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both" - if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) self.expiration = expiration @@ -125,6 +124,7 @@ def __init__( revision=revision, ) + assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both" if block_indices is not None: try: first_block_index, last_block_index = block_indices.split(":") @@ -133,51 +133,68 @@ def __init__( logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)") raise block_indices = range(first_block_index, last_block_index) - else: - # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time, - # this delay decreases the probability of a race condition while choosing the best blocks to serve. - time.sleep(random.random() * max_block_selection_delay) - - assert num_blocks is not None - uids = [f"{prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] - module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf")) - block_indices = choose_best_blocks(num_blocks, module_infos) - self.block_indices = block_indices + self.block_indices, self.num_blocks = block_indices, num_blocks + self.max_block_selection_delay, self.max_balance_check_period = max_block_selection_delay, max_balance_check_period self.stop = threading.Event() if start: self.start() def run(self): - self.module_container = ModuleContainer.create( - dht=self.dht, - prefix=self.prefix, - converted_model_name_or_path=self.converted_model_name_or_path, - block_config=self.block_config, - memory_cache=self.memory_cache, - throughput=self.throughput, - block_indices=self.block_indices, - num_handlers=self.num_handlers, - min_batch_size=self.min_batch_size, - max_batch_size=self.max_batch_size, - inference_max_length=self.inference_max_length, - torch_dtype=self.torch_dtype, - cache_dir=self.cache_dir, - device=self.device, - compression=self.compression, - stats_report_interval=self.stats_report_interval, - update_period=self.update_period, - expiration=self.expiration, - prefetch_batches=self.prefetch_batches, - sender_threads=self.sender_threads, - use_auth_token=self.use_auth_token, - load_in_8bit=self.load_in_8bit, - start=True, - ) - try: - self.stop.wait() - finally: - self.module_container.shutdown() + while True: + block_indices = self._choose_blocks() + self.module_container = ModuleContainer.create( + dht=self.dht, + prefix=self.prefix, + converted_model_name_or_path=self.converted_model_name_or_path, + block_config=self.block_config, + memory_cache=self.memory_cache, + throughput=self.throughput, + block_indices=block_indices, + num_handlers=self.num_handlers, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + inference_max_length=self.inference_max_length, + torch_dtype=self.torch_dtype, + cache_dir=self.cache_dir, + device=self.device, + compression=self.compression, + stats_report_interval=self.stats_report_interval, + update_period=self.update_period, + expiration=self.expiration, + prefetch_batches=self.prefetch_batches, + sender_threads=self.sender_threads, + use_auth_token=self.use_auth_token, + load_in_8bit=self.load_in_8bit, + start=True, + ) + try: + self.module_container.ready.wait() + + while True: + timeout = random.random() * self.max_balance_check_period + if self.stop.wait(timeout): + return + if self._should_choose_other_blocks(): + break # Stop serving this set of modules + finally: + self.module_container.shutdown() + + def _choose_blocks(self) -> List[int]: + if self.block_indices is not None: + return self.block_indices + + # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time, + # this delay decreases the probability of a race condition while choosing the best blocks to serve. + time.sleep(random.random() * self.max_block_selection_delay) + + assert self.num_blocks is not None + uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] + module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf")) + return choose_best_blocks(self.num_blocks, module_infos) + + def _should_choose_other_blocks(self) -> bool: + return False def shutdown(self): self.stop.set() From 74c1b1129bdfea9880d5be91e9e38351bb0d9f96 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Sat, 8 Oct 2022 05:17:32 +0000 Subject: [PATCH 03/12] Update docstrings --- src/server/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/server/server.py b/src/server/server.py index c993fa66c..340be303b 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -30,8 +30,8 @@ class Server(threading.Thread): """ - Runs Server, periodically checks that the network is balanced, - restarts the Server with other layers if the imbalance is significant + Runs ModuleContainer, periodically checks that the network is balanced, + restarts the ModuleContainer with other layers if the imbalance is significant """ def __init__( @@ -243,7 +243,7 @@ def __init__( def run(self): """ Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers, - runs Runtime (self.runtime) to process incoming requests. + runs hivemind.Runtime (self.runtime) to process incoming requests. """ logger.info(f"Serving {len(self.module_backends)} blocks:") for expert_name, backend in self.module_backends.items(): From fdb0f7371bc33f0ec29e47dc3e2c42d7e5343afc Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Sat, 8 Oct 2022 05:38:13 +0000 Subject: [PATCH 04/12] Set mean delays instead of max delays --- src/server/server.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/server/server.py b/src/server/server.py index 340be303b..746460e4d 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -36,7 +36,6 @@ class Server(threading.Thread): def __init__( self, - prefix: Optional[str], converted_model_name_or_path: str, throughput: Union[float, str], @@ -59,8 +58,8 @@ def __init__( expiration: Optional[float] = None, prefetch_batches: int = 1, sender_threads: int = 1, - max_block_selection_delay: float = 1, - max_balance_check_period: float = 600, + mean_block_selection_delay: float = 0.5, + mean_balance_check_period: float = 300, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, @@ -134,7 +133,8 @@ def __init__( raise block_indices = range(first_block_index, last_block_index) self.block_indices, self.num_blocks = block_indices, num_blocks - self.max_block_selection_delay, self.max_balance_check_period = max_block_selection_delay, max_balance_check_period + self.mean_block_selection_delay = mean_block_selection_delay + self.mean_balance_check_period = mean_balance_check_period self.stop = threading.Event() if start: @@ -172,7 +172,7 @@ def run(self): self.module_container.ready.wait() while True: - timeout = random.random() * self.max_balance_check_period + timeout = random.random() * 2 * self.mean_balance_check_period if self.stop.wait(timeout): return if self._should_choose_other_blocks(): @@ -186,7 +186,7 @@ def _choose_blocks(self) -> List[int]: # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time, # this delay decreases the probability of a race condition while choosing the best blocks to serve. - time.sleep(random.random() * self.max_block_selection_delay) + time.sleep(random.random() * 2 * self.mean_block_selection_delay) assert self.num_blocks is not None uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] From bafd66ad509541d52c3f87c356e4d48b876e839e Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Sat, 8 Oct 2022 05:42:56 +0000 Subject: [PATCH 05/12] Add allow_rebalancing flag --- src/server/server.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/server/server.py b/src/server/server.py index 746460e4d..bcee022b0 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -41,6 +41,7 @@ def __init__( throughput: Union[float, str], num_blocks: Optional[int] = None, block_indices: Optional[str] = None, + allow_rebalancing: bool = True, num_handlers: int = 8, min_batch_size: int = 1, max_batch_size: int = 4096, @@ -133,6 +134,7 @@ def __init__( raise block_indices = range(first_block_index, last_block_index) self.block_indices, self.num_blocks = block_indices, num_blocks + self.allow_rebalancing = allow_rebalancing self.mean_block_selection_delay = mean_block_selection_delay self.mean_balance_check_period = mean_balance_check_period @@ -194,7 +196,11 @@ def _choose_blocks(self) -> List[int]: return choose_best_blocks(self.num_blocks, module_infos) def _should_choose_other_blocks(self) -> bool: - return False + if not self.allow_rebalancing: + return False + + # TODO: Implement actual algorithm here + return True def shutdown(self): self.stop.set() From b5d54f42c0f33cc5ee7bb8ed2860a55837a8ef9d Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Sat, 8 Oct 2022 06:05:06 +0000 Subject: [PATCH 06/12] Fix errors --- src/server/server.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/server/server.py b/src/server/server.py index bcee022b0..422fa2b24 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -4,7 +4,7 @@ import random import threading import time -from typing import Dict, Optional, List, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time @@ -60,7 +60,7 @@ def __init__( prefetch_batches: int = 1, sender_threads: int = 1, mean_block_selection_delay: float = 0.5, - mean_balance_check_period: float = 300, + mean_balance_check_period: float = 60, # TODO: use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, @@ -178,6 +178,7 @@ def run(self): if self.stop.wait(timeout): return if self._should_choose_other_blocks(): + logger.info("Network is imbalanced, server will load other blocks") break # Stop serving this set of modules finally: self.module_container.shutdown() @@ -217,7 +218,7 @@ def __init__( dht: DHT, module_backends: Dict[str, TransformerBackend], *, - device: torch.device, + inference_max_length: int, num_connection_handlers: int, throughput: float, update_period: float, @@ -230,9 +231,10 @@ def __init__( self.dht, self.module_backends = dht, module_backends self.throughput, self.update_period, self.expiration = throughput, update_period, expiration self.conn_handlers = [ - TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers) + TransformerConnectionHandler(dht, self.module_backends, inference_max_length) + for _ in range(num_connection_handlers) ] - self.runtime = Runtime(self.module_backends, device=device, **kwargs) + self.runtime = Runtime(self.module_backends, **kwargs) self.dht_handler_thread = ModuleAnnouncerThread( self.module_backends, dht, From f3ea120c8194760617acb71bee887f198904e1a9 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Tue, 11 Oct 2022 05:03:14 +0000 Subject: [PATCH 07/12] Fix ModuleContainer.shutdown() and its usages --- src/server/server.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/server/server.py b/src/server/server.py index 422fa2b24..f6736fab4 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -60,7 +60,7 @@ def __init__( prefetch_batches: int = 1, sender_threads: int = 1, mean_block_selection_delay: float = 0.5, - mean_balance_check_period: float = 60, # TODO: + mean_balance_check_period: float = 300, # TODO: use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, @@ -175,6 +175,7 @@ def run(self): while True: timeout = random.random() * 2 * self.mean_balance_check_period + # TODO: Follow ModuleContainer status (to restart/stop if it crashes) if self.stop.wait(timeout): return if self._should_choose_other_blocks(): @@ -251,7 +252,7 @@ def __init__( def run(self): """ Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers, - runs hivemind.Runtime (self.runtime) to process incoming requests. + runs Runtime (self.runtime) to process incoming requests. """ logger.info(f"Serving {len(self.module_backends)} blocks:") for expert_name, backend in self.module_backends.items(): @@ -267,15 +268,10 @@ def run(self): if self.checkpoint_saver is not None: self.checkpoint_saver.start() - for process in self.conn_handlers: - if not process.is_alive(): - process.start() - process.ready.result() + for handler in self.conn_handlers: + handler.run_in_background() - try: - self.runtime.run() - finally: - self.shutdown() + self.runtime.run() # noinspection PyMethodOverriding @classmethod @@ -413,9 +409,8 @@ def shutdown(self): self.ready.clear() - for process in self.conn_handlers: - process.terminate() - process.join() + for handler in self.conn_handlers: + handler.shutdown() logger.debug("Connection handlers terminated") if self.checkpoint_saver is not None: From e965719b58226a282200b8251fad60ee374d31a9 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Tue, 11 Oct 2022 06:46:07 +0000 Subject: [PATCH 08/12] Simplify rebalancing options --- src/server/server.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/server/server.py b/src/server/server.py index f6736fab4..443a67ee4 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -41,7 +41,6 @@ def __init__( throughput: Union[float, str], num_blocks: Optional[int] = None, block_indices: Optional[str] = None, - allow_rebalancing: bool = True, num_handlers: int = 8, min_batch_size: int = 1, max_batch_size: int = 4096, @@ -60,7 +59,7 @@ def __init__( prefetch_batches: int = 1, sender_threads: int = 1, mean_block_selection_delay: float = 0.5, - mean_balance_check_period: float = 300, # TODO: + mean_balance_check_period: float = 300, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, @@ -133,16 +132,17 @@ def __init__( logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)") raise block_indices = range(first_block_index, last_block_index) - self.block_indices, self.num_blocks = block_indices, num_blocks - self.allow_rebalancing = allow_rebalancing + self.strict_block_indices, self.num_blocks = block_indices, num_blocks self.mean_block_selection_delay = mean_block_selection_delay self.mean_balance_check_period = mean_balance_check_period + self._module_infos = None self.stop = threading.Event() if start: self.start() def run(self): + self._update_module_infos() while True: block_indices = self._choose_blocks() self.module_container = ModuleContainer.create( @@ -178,27 +178,34 @@ def run(self): # TODO: Follow ModuleContainer status (to restart/stop if it crashes) if self.stop.wait(timeout): return + + self._update_module_infos() if self._should_choose_other_blocks(): logger.info("Network is imbalanced, server will load other blocks") break # Stop serving this set of modules finally: self.module_container.shutdown() - def _choose_blocks(self) -> List[int]: - if self.block_indices is not None: - return self.block_indices + def _update_module_infos(self) -> None: + if self.strict_block_indices: + return # No need for self._module_infos in this case # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time, # this delay decreases the probability of a race condition while choosing the best blocks to serve. time.sleep(random.random() * 2 * self.mean_block_selection_delay) - assert self.num_blocks is not None uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] - module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf")) - return choose_best_blocks(self.num_blocks, module_infos) + self._module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf")) + + def _choose_blocks(self) -> List[int]: + if self.strict_block_indices: + return self.strict_block_indices + + assert self.num_blocks is not None + return choose_best_blocks(self.num_blocks, self._module_infos) def _should_choose_other_blocks(self) -> bool: - if not self.allow_rebalancing: + if self.strict_block_indices: return False # TODO: Implement actual algorithm here From 4a05f40786dded4b9fc046d07fca4bf0733e69a5 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Tue, 11 Oct 2022 10:15:25 +0000 Subject: [PATCH 09/12] Implement rebalancing criterion --- src/server/block_selection.py | 105 ++++++++++++++++++++++++++++++---- src/server/server.py | 38 ++++++------ 2 files changed, 110 insertions(+), 33 deletions(-) diff --git a/src/server/block_selection.py b/src/server/block_selection.py index 75ee471f2..940d546cd 100644 --- a/src/server/block_selection.py +++ b/src/server/block_selection.py @@ -1,18 +1,101 @@ -from typing import List, Optional +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import numpy as np +from hivemind import PeerID, get_logger from src.data_structures import RemoteModuleInfo, ServerState +logger = get_logger(__file__) + + +@dataclass +class Span: + start: int + end: int + throughput: float + + @property + def length(self): + return self.end - self.start + + def move_to(self, new_start: int) -> None: + self.start, self.end = new_start, new_start + self.length + -def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]: - throughputs = [] - for module in remote_module_infos: +def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]: + spans = {} + throughputs = np.zeros(len(module_infos)) + for block, module in enumerate(module_infos): if module is None: - throughputs.append(0) continue - throughputs.append( - sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE) - ) - options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)] - best_start = min(options)[1] - return list(range(best_start, best_start + num_blocks)) + for peer_id, server in module.servers.items(): + if server.state == ServerState.OFFLINE: + continue + + if peer_id in spans: + spans[peer_id].start = min(spans[peer_id].start, block) + spans[peer_id].end = max(spans[peer_id].start, block + 1) + else: + spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput) + + throughputs[block] += server.throughput + + return spans, throughputs + + +def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int: + options = ( + (sorted(throughputs[i : i + num_blocks]), i != cur_start, i) + for i in range(0, len(throughputs) - num_blocks + 1) + ) + return min(options)[-1] + + +def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]: + _, throughputs = _compute_spans(module_infos) + start = _choose_best_start(throughputs, num_blocks, None) + return list(range(start, start + num_blocks)) + + +def should_choose_other_blocks( + local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float +) -> bool: + spans, throughputs = _compute_spans(module_infos) + initial_throughput = throughputs.min() + + assert local_peer_id in spans, "Span served by this server is not present in the DHT" + local_span = spans[local_peer_id] + throughputs[local_span.start : local_span.end] -= local_span.throughput + + new_start = _choose_best_start(throughputs, local_span.length, local_span.start) + if local_span.start == new_start: + return False # This server is on its best place already + local_span.move_to(new_start) + + throughputs[local_span.start : local_span.end] += local_span.throughput + + moved = True + while moved: + servers = list(spans.keys()) + np.random.shuffle(servers) + + moved = False + for peer_id in servers: + span = spans[peer_id] + throughputs[span.start : span.end] -= span.throughput + + new_start = _choose_best_start(throughputs, span.length, span.start) + if span.start != new_start: + span.move_to(new_start) + moved = True + + throughputs[span.start : span.end] += span.throughput + + new_throughput = throughputs.min() + balance_quality = initial_throughput / new_throughput + logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%") + + eps = 1e-6 + return balance_quality < min_balance_quality - eps diff --git a/src/server/server.py b/src/server/server.py index 443a67ee4..7e2f99bc9 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -6,6 +6,7 @@ import time from typing import Dict, List, Optional, Sequence, Union +import numpy as np import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file @@ -17,8 +18,8 @@ from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from src.dht_utils import get_remote_module_infos +from src.server import block_selection from src.server.backend import TransformerBackend -from src.server.block_selection import choose_best_blocks from src.server.cache import MemoryCache from src.server.handler import TransformerConnectionHandler from src.server.throughput import get_host_throughput @@ -59,7 +60,8 @@ def __init__( prefetch_batches: int = 1, sender_threads: int = 1, mean_block_selection_delay: float = 0.5, - mean_balance_check_period: float = 300, + mean_balance_check_period: float = 150, + min_balance_quality: float = 0.8, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, @@ -122,6 +124,7 @@ def __init__( use_auth_token=use_auth_token, revision=revision, ) + self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both" if block_indices is not None: @@ -135,14 +138,13 @@ def __init__( self.strict_block_indices, self.num_blocks = block_indices, num_blocks self.mean_block_selection_delay = mean_block_selection_delay self.mean_balance_check_period = mean_balance_check_period - self._module_infos = None + self.min_balance_quality = min_balance_quality self.stop = threading.Event() if start: self.start() def run(self): - self._update_module_infos() while True: block_indices = self._choose_blocks() self.module_container = ModuleContainer.create( @@ -179,37 +181,29 @@ def run(self): if self.stop.wait(timeout): return - self._update_module_infos() if self._should_choose_other_blocks(): - logger.info("Network is imbalanced, server will load other blocks") + logger.info("Swarm is imbalanced, server will load other blocks") break # Stop serving this set of modules finally: self.module_container.shutdown() - def _update_module_infos(self) -> None: - if self.strict_block_indices: - return # No need for self._module_infos in this case + def _choose_blocks(self) -> List[int]: + if self.strict_block_indices is not None: + return self.strict_block_indices + assert self.num_blocks is not None # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time, # this delay decreases the probability of a race condition while choosing the best blocks to serve. time.sleep(random.random() * 2 * self.mean_block_selection_delay) - - uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)] - self._module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf")) - - def _choose_blocks(self) -> List[int]: - if self.strict_block_indices: - return self.strict_block_indices - - assert self.num_blocks is not None - return choose_best_blocks(self.num_blocks, self._module_infos) + module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf) + return block_selection.choose_best_blocks(self.num_blocks, module_infos) def _should_choose_other_blocks(self) -> bool: - if self.strict_block_indices: + if self.strict_block_indices is not None: return False - # TODO: Implement actual algorithm here - return True + module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf) + return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality) def shutdown(self): self.stop.set() From 1cf15a7f482968063638f26b3f6dc25a1158d4e1 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Wed, 12 Oct 2022 09:57:19 +0000 Subject: [PATCH 10/12] Disable rebalancing by default, address some of review comments --- src/server/block_selection.py | 2 ++ src/server/server.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/server/block_selection.py b/src/server/block_selection.py index 940d546cd..bfe048249 100644 --- a/src/server/block_selection.py +++ b/src/server/block_selection.py @@ -6,6 +6,8 @@ from src.data_structures import RemoteModuleInfo, ServerState +__all__ = ['choose_best_blocks', 'should_choose_other_blocks'] + logger = get_logger(__file__) diff --git a/src/server/server.py b/src/server/server.py index 7e2f99bc9..5879ae5f8 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -61,7 +61,7 @@ def __init__( sender_threads: int = 1, mean_block_selection_delay: float = 0.5, mean_balance_check_period: float = 150, - min_balance_quality: float = 0.8, + min_balance_quality: float = 0.0, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, From 4e26b2799d553734a67d316b4269610f3eca78c7 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Wed, 12 Oct 2022 10:14:13 +0000 Subject: [PATCH 11/12] Add rebalancing options to CLI args --- cli/run_server.py | 7 +++++++ src/server/server.py | 8 ++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cli/run_server.py b/cli/run_server.py index 04738db08..80598a0df 100644 --- a/cli/run_server.py +++ b/cli/run_server.py @@ -79,6 +79,13 @@ def main(): parser.add_argument('--custom_module_path', type=str, required=False, help='Path of a file with custom nn.modules, wrapped into special decorator') parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P') + + parser.add_argument("--min_balance_quality", type=float, default=0.0, + help="Rebalance the swarm if its balance quality (a number in [0.0, 1.0]) " + "goes below this threshold. Default: rebalancing is disabled") + parser.add_argument("--mean_balance_check_period", type=float, default=150, + help="Check the swarm's balance every N seconds (and rebalance it if necessary)") + parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained") parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.') diff --git a/src/server/server.py b/src/server/server.py index 5879ae5f8..a0c1c09ea 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -59,9 +59,9 @@ def __init__( expiration: Optional[float] = None, prefetch_batches: int = 1, sender_threads: int = 1, - mean_block_selection_delay: float = 0.5, - mean_balance_check_period: float = 150, min_balance_quality: float = 0.0, + mean_balance_check_period: float = 150, + mean_block_selection_delay: float = 0.5, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, *, @@ -136,9 +136,9 @@ def __init__( raise block_indices = range(first_block_index, last_block_index) self.strict_block_indices, self.num_blocks = block_indices, num_blocks - self.mean_block_selection_delay = mean_block_selection_delay - self.mean_balance_check_period = mean_balance_check_period self.min_balance_quality = min_balance_quality + self.mean_balance_check_period = mean_balance_check_period + self.mean_block_selection_delay = mean_block_selection_delay self.stop = threading.Event() if start: From 7dc7da3e2eeda3cb4bea34dafd88738d0458bcda Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Wed, 12 Oct 2022 10:18:43 +0000 Subject: [PATCH 12/12] black --- src/server/block_selection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/block_selection.py b/src/server/block_selection.py index bfe048249..fe926e91a 100644 --- a/src/server/block_selection.py +++ b/src/server/block_selection.py @@ -6,7 +6,7 @@ from src.data_structures import RemoteModuleInfo, ServerState -__all__ = ['choose_best_blocks', 'should_choose_other_blocks'] +__all__ = ["choose_best_blocks", "should_choose_other_blocks"] logger = get_logger(__file__)