Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix routing through relay, default network RPS, --token, logging, readme #399

Merged
merged 10 commits into from
Jul 22, 2023
Previous commit
Next commit
Use default network speed 25 Mbit/s
  • Loading branch information
borzunov committed Jul 22, 2023
commit 09d3d180bff7a0cac307bbca8edd13039be1d837
61 changes: 30 additions & 31 deletions src/petals/server/throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,10 @@ def measure_throughput_info(
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
) -> Dict[str, float]:
"""Measure network and compute throughput in forward pass tokens per second"""

logger.info(
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
)

throughput_info = {
return {
"inference_rps": measure_compute_rps(
config,
device,
Expand All @@ -136,37 +133,39 @@ def measure_throughput_info(
n_steps=10,
inference=False,
),
"network_rps": measure_network_rps(config),
}
try:
throughput_info["network_rps"] = measure_network_rps(config)
except Exception as e:
logger.info(f"Network throughput is not available: {e}")
return throughput_info


def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
process.start()

if not pipe_recv.poll(timeout):
process.terminate()
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
network_info = pipe_recv.recv()
if "exception" in network_info:
raise RuntimeError(f"speedtest failed: {network_info['exception']}")

def measure_network_rps(
config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 25e6
) -> Optional[float]:
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise RuntimeError("speedtest has returned network_rps == 0")

logger.info(
f"Network throughput: {network_rps:.1f} RPS "
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
)
return network_rps
try:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
process.start()

if not pipe_recv.poll(timeout):
process.terminate()
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
network_info = pipe_recv.recv()
if "exception" in network_info:
raise RuntimeError(f"speedtest failed: {network_info['exception']}")

network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise RuntimeError("speedtest has returned network_rps == 0")

logger.info(
f"Network throughput: {network_rps:.1f} RPS "
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
)
return network_rps
except RuntimeError as e:
logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s")
return default_speed / bits_per_request


def _measure_bits_per_second(pipe_send: mp.Pipe):
Expand Down
Loading