Fix routing through relay, default network RPS, --token, logging, readme #399

Merged
merged 10 commits into from
Jul 22, 2023
Apply relay penalty in max-throughput routing
borzunov committed Jul 22, 2023
commit 1c1c44021aa2c6a6b4eb493c8eb939f598d842ac
12 changes: 10 additions & 2 deletions src/petals/client/routing/sequence_manager.py
@@ -291,15 +291,23 @@ def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = No
         # This is okay since false positives are more costly than false negatives here.
         return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
 
-    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
+    def _make_sequence_with_max_throughput(
+        self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
+    ) -> List[RemoteSpanInfo]:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
             candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
+            span_weights = np.array(
+                [
+                    span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
+                    for span in candidate_spans
+                ],
+                dtype=np.float64,
+            )
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
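
Reviewer note: the effect of `relay_penalty` on the sampling distribution can be illustrated with a small standalone sketch. This is not the Petals code. `ToySpan`, `relay_weights`, `selection_probabilities`, and `choose_span` are hypothetical names introduced here, and the standard library's `random.choices` stands in for `np.random.choice` so that the snippet has no dependencies. It assumes only what the hunk shows: each candidate's throughput is multiplied by `relay_penalty` when the server is reached through a relay, and the results are used as sampling weights.

```python
import random
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToySpan:
    # Hypothetical stand-in for RemoteSpanInfo; it holds only the two
    # server_info fields that the weighting in the hunk reads.
    name: str
    throughput: float
    using_relay: bool


def relay_weights(spans: List[ToySpan], relay_penalty: float = 0.5) -> List[float]:
    # Same rule as in the diff: keep the throughput for direct servers,
    # scale it by relay_penalty for servers reached through a relay.
    return [s.throughput * (relay_penalty if s.using_relay else 1.0) for s in spans]


def selection_probabilities(spans: List[ToySpan], relay_penalty: float = 0.5) -> List[float]:
    # Normalize the weights so they sum to 1, as `span_weights / span_weights.sum()` does.
    weights = relay_weights(spans, relay_penalty)
    total = sum(weights)
    return [w / total for w in weights]


def choose_span(
    spans: List[ToySpan], relay_penalty: float = 0.5, rng: Optional[random.Random] = None
) -> ToySpan:
    # random.choices normalizes the weights itself; it replaces
    # np.random.choice(..., p=...) in this dependency-free sketch.
    picker = rng if rng is not None else random
    return picker.choices(spans, weights=relay_weights(spans, relay_penalty), k=1)[0]


# Two servers with equal raw throughput; one of them is behind a relay.
spans = [ToySpan("direct", 100.0, False), ToySpan("relayed", 100.0, True)]
weights = relay_weights(spans)
probs = selection_probabilities(spans)
picked = choose_span(spans, rng=random.Random(0))
```

With the default penalty of 0.5 and equal raw throughput, the direct server is picked about twice as often as the relayed one (probabilities 2/3 and 1/3). The relayed server is deprioritized rather than excluded, so it is still selected when it is the only candidate for a block.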