Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make each server ping the servers hosting the next block #356

Merged
merged 8 commits into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Ping next servers
  • Loading branch information
borzunov committed Jul 15, 2023
commit 4b5ca4a56ef988d0e3d63c208d954c0886b0b60e
1 change: 1 addition & 0 deletions src/petals/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ServerInfo:
quant_type: Optional[str] = None
using_relay: Optional[bool] = None
cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None

def to_tuple(self) -> Tuple[int, float, dict]:
extra_info = dataclasses.asdict(self)
Expand Down
22 changes: 21 additions & 1 deletion src/petals/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import random
import threading
import time
from functools import partial
from typing import Dict, List, Optional, Sequence, Union

import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, PeerID, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
Expand All @@ -30,6 +31,7 @@
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.ping import ping_parallel
from petals.utils.version import get_compatible_model_repo

logger = get_logger(__name__)
Expand Down Expand Up @@ -680,9 +682,17 @@ def __init__(
self.expiration = expiration
self.stop = threading.Event()

last_uid = max(module_uids, key=lambda uid: int(uid.split(UID_DELIMITER)[-1]))
dht_prefix, block_index = last_uid.split(UID_DELIMITER)
self.next_uid = f"{dht_prefix}{UID_DELIMITER}{int(block_index) + 1}"

def run(self) -> None:
while True:
self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
next_pings = self._ping_next_servers()
self.server_info.next_pings = {peer_id.to_base58(): rtt for peer_id, rtt in next_pings.items()}
logger.warning(f"Pings to servers with {self.next_uid}: {self.server_info.next_pings=}")

declare_active_modules(
self.dht,
self.module_uids,
Expand All @@ -692,6 +702,16 @@ def run(self) -> None:
if self.stop.wait(self.update_period):
break

def _ping_next_servers(self, max_servers: int = 10) -> Dict[PeerID, float]:
[module_info] = get_remote_module_infos(self.dht, [self.next_uid], latest=True)
if module_info is None:
return {}

next_servers = list(module_info.servers)
if len(next_servers) > max_servers:
next_servers = random.sample(next_servers, max_servers)
return self.dht.run_coroutine(partial(ping_parallel, next_servers))


class RuntimeWithDeduplicatedPools(Runtime):
"""A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
Expand Down