From 4a05f40786dded4b9fc046d07fca4bf0733e69a5 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Tue, 11 Oct 2022 10:15:25 +0000
Subject: [PATCH] Implement rebalancing criterion

---
 src/server/block_selection.py | 105 ++++++++++++++++++++++++++++++----
 src/server/server.py          |  38 ++++++------
 2 files changed, 110 insertions(+), 33 deletions(-)

diff --git a/src/server/block_selection.py b/src/server/block_selection.py
index 75ee471f2..940d546cd 100644
--- a/src/server/block_selection.py
+++ b/src/server/block_selection.py
@@ -1,18 +1,101 @@
-from typing import List, Optional
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from hivemind import PeerID, get_logger
 
 from src.data_structures import RemoteModuleInfo, ServerState
 
+logger = get_logger(__file__)
+
+
+@dataclass
+class Span:
+    start: int
+    end: int
+    throughput: float
+
+    @property
+    def length(self):
+        return self.end - self.start
+
+    def move_to(self, new_start: int) -> None:
+        self.start, self.end = new_start, new_start + self.length
+
 
-def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    throughputs = []
-    for module in remote_module_infos:
+def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
+    spans = {}
+    throughputs = np.zeros(len(module_infos))
+    for block, module in enumerate(module_infos):
         if module is None:
-            throughputs.append(0)
             continue
-        throughputs.append(
-            sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE)
-        )
 
-    options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
-    best_start = min(options)[1]
-    return list(range(best_start, best_start + num_blocks))
+        for peer_id, server in module.servers.items():
+            if server.state == ServerState.OFFLINE:
+                continue
+
+            if peer_id in spans:
+                spans[peer_id].start = min(spans[peer_id].start, block)
+                spans[peer_id].end = max(spans[peer_id].end, block + 1)
+            else:
+                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
+
+            throughputs[block] += server.throughput
+
+    return spans, throughputs
+
+
+def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
+    options = (
+        (sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
+        for i in range(0, len(throughputs) - num_blocks + 1)
+    )
+    return min(options)[-1]
+
+
+def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
+    _, throughputs = _compute_spans(module_infos)
+    start = _choose_best_start(throughputs, num_blocks, None)
+    return list(range(start, start + num_blocks))
+
+
+def should_choose_other_blocks(
+    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
+) -> bool:
+    spans, throughputs = _compute_spans(module_infos)
+    initial_throughput = throughputs.min()
+
+    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
+    local_span = spans[local_peer_id]
+    throughputs[local_span.start : local_span.end] -= local_span.throughput
+
+    new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
+    if local_span.start == new_start:
+        return False  # This server is already at its best position
+    local_span.move_to(new_start)
+
+    throughputs[local_span.start : local_span.end] += local_span.throughput
+
+    moved = True
+    while moved:
+        servers = list(spans.keys())
+        np.random.shuffle(servers)
+
+        moved = False
+        for peer_id in servers:
+            span = spans[peer_id]
+            throughputs[span.start : span.end] -= span.throughput
+
+            new_start = _choose_best_start(throughputs, span.length, span.start)
+            if span.start != new_start:
+                span.move_to(new_start)
+                moved = True
+
+            throughputs[span.start : span.end] += span.throughput
+
+    new_throughput = throughputs.min()
+    balance_quality = initial_throughput / new_throughput
+    logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")
+
+    eps = 1e-6
+    return balance_quality < min_balance_quality - eps
diff --git a/src/server/server.py b/src/server/server.py
index 443a67ee4..7e2f99bc9 100644
--- a/src/server/server.py
+++ b/src/server/server.py
@@ -6,6 +6,7 @@
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
+import numpy as np
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -17,8 +18,8 @@
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
+from src.server import block_selection
 from src.server.backend import TransformerBackend
-from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 from src.server.throughput import get_host_throughput
@@ -59,7 +60,8 @@ def __init__(
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         mean_block_selection_delay: float = 0.5,
-        mean_balance_check_period: float = 300,
+        mean_balance_check_period: float = 150,
+        min_balance_quality: float = 0.8,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         *,
@@ -122,6 +124,7 @@ def __init__(
             use_auth_token=use_auth_token,
             revision=revision,
         )
+        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if block_indices is not None:
@@ -135,14 +138,13 @@ def __init__(
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
         self.mean_block_selection_delay = mean_block_selection_delay
         self.mean_balance_check_period = mean_balance_check_period
-        self._module_infos = None
+        self.min_balance_quality = min_balance_quality
 
         self.stop = threading.Event()
        if start:
             self.start()
 
     def run(self):
-        self._update_module_infos()
         while True:
             block_indices = self._choose_blocks()
             self.module_container = ModuleContainer.create(
@@ -179,37 +181,29 @@ def run(self):
             if self.stop.wait(timeout):
                 return
 
-            self._update_module_infos()
             if self._should_choose_other_blocks():
-                logger.info("Network is imbalanced, server will load other blocks")
+                logger.info("Swarm is imbalanced, server will load other blocks")
                 break  # Stop serving this set of modules
         finally:
             self.module_container.shutdown()
 
-    def _update_module_infos(self) -> None:
-        if self.strict_block_indices:
-            return  # No need for self._module_infos in this case
+    def _choose_blocks(self) -> List[int]:
+        if self.strict_block_indices is not None:
+            return self.strict_block_indices
+        assert self.num_blocks is not None
 
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
         time.sleep(random.random() * 2 * self.mean_block_selection_delay)
-
-        uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
-        self._module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
-
-    def _choose_blocks(self) -> List[int]:
-        if self.strict_block_indices:
-            return self.strict_block_indices
-
-        assert self.num_blocks is not None
-        return choose_best_blocks(self.num_blocks, self._module_infos)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.choose_best_blocks(self.num_blocks, module_infos)
 
     def _should_choose_other_blocks(self) -> bool:
-        if self.strict_block_indices:
+        if self.strict_block_indices is not None:
             return False
 
-        # TODO: Implement actual algorithm here
-        return True
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality)
 
     def shutdown(self):
         self.stop.set()
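A minimal sketch of the new criterion in action (not part of the patch). The stub classes below mirror only the fields of src.data_structures that _compute_spans() reads (a servers dict plus each server's state and throughput); plain strings stand in for hivemind PeerID values, and the 4-block swarm layout is hypothetical.

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict

from src.server.block_selection import should_choose_other_blocks


class ServerState(Enum):
    # Stub: the criterion only checks inequality against the real ServerState.OFFLINE
    OFFLINE = 0
    JOINING = 1
    ONLINE = 2


@dataclass
class ServerInfo:
    state: ServerState
    throughput: float  # this server's contribution to each block it serves


@dataclass
class RemoteModuleInfo:
    uid: str
    servers: Dict[str, ServerInfo] = field(default_factory=dict)


# Two servers crowd blocks 0-1 of a 4-block model, leaving blocks 2-3 unserved:
module_infos = [RemoteModuleInfo(uid=f"block.{i}") for i in range(4)]
for peer_id in ("peer_A", "peer_B"):
    for block in (0, 1):
        module_infos[block].servers[peer_id] = ServerInfo(ServerState.ONLINE, throughput=1.0)

# The swarm's bottleneck throughput is 0 (blocks 2-3 have no servers). Simulating
# a rebalance moves peer_A to blocks 2-3 and lifts the bottleneck to 1.0, so
# balance quality = 0 / 1.0 = 0% < 80%, and the server decides to move:
assert should_choose_other_blocks("peer_A", module_infos, min_balance_quality=0.8)

Under these stubs the call logs "Swarm balance quality: 0.0%" and returns True: with min_balance_quality=0.8, a server relocates only when the swarm's current worst per-block throughput falls below 80% of what the simulated rebalance predicts.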