Skip to content

Commit

Permalink
Set mean delays instead of max delays
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Oct 8, 2022
1 parent 74c1b11 commit fdb0f73
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class Server(threading.Thread):

def __init__(
self,

prefix: Optional[str],
converted_model_name_or_path: str,
throughput: Union[float, str],
Expand All @@ -59,8 +58,8 @@ def __init__(
expiration: Optional[float] = None,
prefetch_batches: int = 1,
sender_threads: int = 1,
max_block_selection_delay: float = 1,
max_balance_check_period: float = 600,
mean_block_selection_delay: float = 0.5,
mean_balance_check_period: float = 300,
use_auth_token: Optional[str] = None,
load_in_8bit: bool = False,
*,
Expand Down Expand Up @@ -134,7 +133,8 @@ def __init__(
raise
block_indices = range(first_block_index, last_block_index)
self.block_indices, self.num_blocks = block_indices, num_blocks
self.max_block_selection_delay, self.max_balance_check_period = max_block_selection_delay, max_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay
self.mean_balance_check_period = mean_balance_check_period

self.stop = threading.Event()
if start:
Expand Down Expand Up @@ -172,7 +172,7 @@ def run(self):
self.module_container.ready.wait()

while True:
timeout = random.random() * self.max_balance_check_period
timeout = random.random() * 2 * self.mean_balance_check_period
if self.stop.wait(timeout):
return
if self._should_choose_other_blocks():
Expand All @@ -186,7 +186,7 @@ def _choose_blocks(self) -> List[int]:

# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
time.sleep(random.random() * self.max_block_selection_delay)
time.sleep(random.random() * 2 * self.mean_block_selection_delay)

assert self.num_blocks is not None
uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
Expand Down

0 comments on commit fdb0f73

Please sign in to comment.