Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebalance swarm when necessary #34

Merged
merged 12 commits into from
Oct 12, 2022
Previous commit
Next commit
Set mean delays instead of max delays
  • Loading branch information
borzunov committed Oct 8, 2022
commit fdb0f7371bc33f0ec29e47dc3e2c42d7e5343afc
12 changes: 6 additions & 6 deletions src/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class Server(threading.Thread):

def __init__(
self,

prefix: Optional[str],
converted_model_name_or_path: str,
throughput: Union[float, str],
Expand All @@ -59,8 +58,8 @@ def __init__(
expiration: Optional[float] = None,
prefetch_batches: int = 1,
sender_threads: int = 1,
max_block_selection_delay: float = 1,
max_balance_check_period: float = 600,
mean_block_selection_delay: float = 0.5,
mean_balance_check_period: float = 300,
use_auth_token: Optional[str] = None,
load_in_8bit: bool = False,
*,
Expand Down Expand Up @@ -134,7 +133,8 @@ def __init__(
raise
block_indices = range(first_block_index, last_block_index)
self.block_indices, self.num_blocks = block_indices, num_blocks
self.max_block_selection_delay, self.max_balance_check_period = max_block_selection_delay, max_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay
self.mean_balance_check_period = mean_balance_check_period

self.stop = threading.Event()
if start:
Expand Down Expand Up @@ -172,7 +172,7 @@ def run(self):
self.module_container.ready.wait()

while True:
timeout = random.random() * self.max_balance_check_period
timeout = random.random() * 2 * self.mean_balance_check_period
if self.stop.wait(timeout):
return
if self._should_choose_other_blocks():
Expand All @@ -186,7 +186,7 @@ def _choose_blocks(self) -> List[int]:

# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
time.sleep(random.random() * self.max_block_selection_delay)
time.sleep(random.random() * 2 * self.mean_block_selection_delay)

assert self.num_blocks is not None
uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
Expand Down