Skip to content

Commit

Permalink
Rebalance swarm when necessary (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Oct 12, 2022
1 parent 640bbc3 commit 149f433
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 119 deletions.
11 changes: 9 additions & 2 deletions cli/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main():
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=16384,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
parser.add_argument('--cache_dir', type=str, default=None,
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
Expand Down Expand Up @@ -79,6 +79,13 @@ def main():
parser.add_argument('--custom_module_path', type=str, required=False,
help='Path of a file with custom nn.modules, wrapped into special decorator')
parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')

parser.add_argument("--min_balance_quality", type=float, default=0.0,
help="Rebalance the swarm if its balance quality (a number in [0.0, 1.0]) "
"goes below this threshold. Default: rebalancing is disabled")
parser.add_argument("--mean_balance_check_period", type=float, default=150,
help="Check the swarm's balance every N seconds (and rebalance it if necessary)")

parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')

Expand All @@ -104,7 +111,7 @@ def main():
use_auth_token = args.pop("use_auth_token")
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token

server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)

try:
server.join()
Expand Down
107 changes: 96 additions & 11 deletions src/server/block_selection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,103 @@
from typing import List, Optional
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from hivemind import PeerID, get_logger

from src.data_structures import RemoteModuleInfo, ServerState

__all__ = ["choose_best_blocks", "should_choose_other_blocks"]

logger = get_logger(__file__)


@dataclass
class Span:
    """A half-open range of block indices ``[start, end)`` with a throughput value."""

    start: int
    end: int
    throughput: float

    @property
    def length(self):
        """Number of blocks covered by the span (``end - start``)."""
        return self.end - self.start

    def move_to(self, new_start: int) -> None:
        """Shift the span so that it begins at ``new_start`` while keeping its length."""
        # Capture the width before touching ``start``: ``length`` is derived from it.
        width = self.length
        self.start = new_start
        self.end = new_start + width


def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
    """Collect the span served by every non-offline peer and the total throughput per block.

    :param module_infos: one entry per block; ``None`` entries contribute nothing
    :returns: a pair ``(spans, throughputs)`` where ``spans`` maps each peer to the
        ``Span`` covering all blocks it was seen on (with the throughput reported on the
        first such block), and ``throughputs[i]`` is the sum of the throughputs of all
        non-offline servers listed for block ``i``
    """
    spans: Dict[PeerID, Span] = {}
    throughputs = np.zeros(len(module_infos))
    for block, module in enumerate(module_infos):
        if module is None:
            continue

        for peer_id, server in module.servers.items():
            if server.state == ServerState.OFFLINE:
                continue

            span = spans.get(peer_id)
            if span is None:
                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
            else:
                # Widen the existing span; the end must be compared with the previous
                # end (not the start) so the result does not depend on visiting order.
                span.start = min(span.start, block)
                span.end = max(span.end, block + 1)

            throughputs[block] += server.throughput

    return spans, throughputs


def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
options = (
(sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
for i in range(0, len(throughputs) - num_blocks + 1)
)
return min(options)[-1]


def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
    """Return ``num_blocks`` consecutive block indices for a server to serve.

    The per-block throughputs are computed from ``module_infos`` and the window
    is picked by ``_choose_best_start`` with no current position to prefer.
    """
    per_block_throughput = _compute_spans(module_infos)[1]
    first = _choose_best_start(per_block_throughput, num_blocks, None)
    return [*range(first, first + num_blocks)]


def should_choose_other_blocks(
    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
) -> bool:
    """Decide whether this server should switch to another span of blocks.

    The function simulates a rebalancing: first the local server is moved to its best
    position, then all servers are repeatedly moved (in random order) to their best
    positions until nobody moves. The balance quality is the ratio between the minimal
    per-block throughput before and after this simulation.

    :param local_peer_id: peer id of this server; must be present in ``module_infos``
    :param module_infos: one entry per block, as consumed by ``_compute_spans``
    :param min_balance_quality: threshold below which the function returns True
    :returns: True if the balance quality is below ``min_balance_quality``; False if the
        local server is already at its best position or if no block throughput remains
        positive after the simulation
    :raises AssertionError: if ``local_peer_id`` serves no block in ``module_infos``
    """
    spans, throughputs = _compute_spans(module_infos)
    initial_throughput = throughputs.min()

    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
    local_span = spans[local_peer_id]
    # Temporarily remove this server's contribution to see where it is needed most.
    throughputs[local_span.start : local_span.end] -= local_span.throughput

    new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
    if local_span.start == new_start:
        return False  # This server is on its best place already
    local_span.move_to(new_start)

    throughputs[local_span.start : local_span.end] += local_span.throughput

    moved = True
    while moved:
        servers = list(spans.keys())
        np.random.shuffle(servers)

        moved = False
        for peer_id in servers:
            span = spans[peer_id]
            throughputs[span.start : span.end] -= span.throughput

            new_start = _choose_best_start(throughputs, span.length, span.start)
            if span.start != new_start:
                span.move_to(new_start)
                moved = True

            throughputs[span.start : span.end] += span.throughput

    new_throughput = throughputs.min()
    if new_throughput <= 0:
        # Some block is still unserved (or only a floating-point residue is left):
        # the ratio below would be NaN/inf or negative, so there is no gain to report.
        return False

    balance_quality = initial_throughput / new_throughput
    logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")

    eps = 1e-6
    return balance_quality < min_balance_quality - eps
Loading

0 comments on commit 149f433

Please sign in to comment.