Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebalance swarm when necessary #34

Merged
merged 12 commits into from
Oct 12, 2022
4 changes: 2 additions & 2 deletions cli/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main():
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=16384,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
parser.add_argument('--cache_dir', type=str, default=None,
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
Expand Down Expand Up @@ -104,7 +104,7 @@ def main():
use_auth_token = args.pop("use_auth_token")
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token

server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
borzunov marked this conversation as resolved.
Show resolved Hide resolved

try:
server.join()
Expand Down
105 changes: 94 additions & 11 deletions src/server/block_selection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,101 @@
from typing import List, Optional
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
from hivemind import PeerID, get_logger

from src.data_structures import RemoteModuleInfo, ServerState

logger = get_logger(__file__)


@dataclass
class Span:
borzunov marked this conversation as resolved.
Show resolved Hide resolved
start: int
end: int
throughput: float

@property
def length(self):
return self.end - self.start

def move_to(self, new_start: int) -> None:
self.start, self.end = new_start, new_start + self.length


def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
throughputs = []
for module in remote_module_infos:
def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
spans = {}
throughputs = np.zeros(len(module_infos))
for block, module in enumerate(module_infos):
if module is None:
throughputs.append(0)
continue
throughputs.append(
sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE)
)

options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
best_start = min(options)[1]
return list(range(best_start, best_start + num_blocks))
for peer_id, server in module.servers.items():
if server.state == ServerState.OFFLINE:
continue

if peer_id in spans:
spans[peer_id].start = min(spans[peer_id].start, block)
spans[peer_id].end = max(spans[peer_id].start, block + 1)
else:
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)

throughputs[block] += server.throughput

return spans, throughputs


def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
options = (
(sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
for i in range(0, len(throughputs) - num_blocks + 1)
)
return min(options)[-1]


def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
_, throughputs = _compute_spans(module_infos)
start = _choose_best_start(throughputs, num_blocks, None)
return list(range(start, start + num_blocks))


def should_choose_other_blocks(
local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
) -> bool:
spans, throughputs = _compute_spans(module_infos)
initial_throughput = throughputs.min()

assert local_peer_id in spans, "Span served by this server is not present in the DHT"
local_span = spans[local_peer_id]
throughputs[local_span.start : local_span.end] -= local_span.throughput

new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
if local_span.start == new_start:
return False # This server is on its best place already
local_span.move_to(new_start)

throughputs[local_span.start : local_span.end] += local_span.throughput

moved = True
while moved:
servers = list(spans.keys())
np.random.shuffle(servers)

moved = False
for peer_id in servers:
span = spans[peer_id]
throughputs[span.start : span.end] -= span.throughput

new_start = _choose_best_start(throughputs, span.length, span.start)
if span.start != new_start:
span.move_to(new_start)
moved = True

throughputs[span.start : span.end] += span.throughput

new_throughput = throughputs.min()
balance_quality = initial_throughput / new_throughput
logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")

eps = 1e-6
return balance_quality < min_balance_quality - eps
Loading