
Fix routing through relay, default network RPS, --token, logging, readme #399

Merged · 10 commits · Jul 22, 2023
8 changes: 6 additions & 2 deletions README.md
@@ -34,11 +34,13 @@ print(tokenizer.decode(outputs[0])) # A cat sat on a mat...

### Connect your GPU and increase Petals capacity

Petals is a community-run system — we rely on people sharing their GPUs. You can check out available servers on our [swarm monitor](https://health.petals.dev) and connect your GPU to help serve one of the models!

Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+):

```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install --upgrade petals
pip install git+https://github.com/bigscience-workshop/petals
python -m petals.cli.run_server enoch/llama-65b-hf --adapters timdettmers/guanaco-65b
```

@@ -55,6 +57,8 @@ This will host a part of LLaMA-65B with optional [Guanaco](https://huggingface.c

💬 See [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues or feedback, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!

🏆 If you host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks! You can specify it with `--public_name YOUR_NAME`. We will show it once your server loads all blocks.

### Check out tutorials, examples, and more

Basic tutorials:
@@ -97,7 +101,7 @@ Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/d

```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install --upgrade petals
pip install git+https://github.com/bigscience-workshop/petals
```

If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
12 changes: 10 additions & 2 deletions src/petals/client/routing/sequence_manager.py
@@ -291,15 +291,23 @@ def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = No
# This is okay since false positives are more costly than false negatives here.
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left

def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
def _make_sequence_with_max_throughput(
self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
) -> List[RemoteSpanInfo]:
span_sequence = []
current_index = start_index
while current_index < end_index:
candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
if not candidate_spans:
raise MissingBlocksError(current_index)

span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
span_weights = np.array(
[
span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
for span in candidate_spans
],
dtype=np.float64,
)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())

assert chosen_span.start <= current_index < chosen_span.end
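Outside the diff, here is a minimal standalone sketch of the relay-penalized weighting introduced above. It uses hypothetical `(name, throughput, using_relay)` tuples rather than the actual `RemoteSpanInfo` objects, just to illustrate how `relay_penalty=0.5` roughly halves how often a relayed server is picked:

```python
import numpy as np

def choose_server(servers, relay_penalty=0.5):
    """Pick a server name with probability proportional to its (penalized) throughput."""
    # servers: list of (name, throughput_rps, using_relay) tuples -- illustrative data only
    weights = np.array(
        [rps * (relay_penalty if using_relay else 1.0) for _, rps, using_relay in servers],
        dtype=np.float64,
    )
    idx = np.random.choice(len(servers), p=weights / weights.sum())
    return servers[idx][0]

servers = [("direct-a", 800.0, False), ("relayed-b", 800.0, True)]
# With equal raw throughput, "direct-a" is chosen ~2/3 of the time and "relayed-b" ~1/3.
print(choose_server(servers))
```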
2 changes: 1 addition & 1 deletion src/petals/server/from_pretrained.py
@@ -40,7 +40,7 @@ def load_pretrained_block(
max_disk_space: Optional[int] = None,
) -> nn.Module:
if config is None:
config = AutoDistributedConfig.from_pretrained(model_name, token=token)
config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR

2 changes: 1 addition & 1 deletion src/petals/server/handler.py
@@ -347,7 +347,7 @@ async def _iterate_inference_steps(
anext_task.cancel()
get_push_task.cancel()
return
except:
except Exception:
logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
raise

4 changes: 2 additions & 2 deletions src/petals/server/server.py
@@ -104,7 +104,7 @@ def __init__(

self.block_config = AutoDistributedConfig.from_pretrained(
converted_model_name_or_path,
token=token,
use_auth_token=token,
revision=revision,
)

@@ -117,7 +117,7 @@ self.dht_prefix = dht_prefix
self.dht_prefix = dht_prefix

if expiration is None:
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
expiration = max(3 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
self.expiration = expiration

self.request_timeout = request_timeout
65 changes: 32 additions & 33 deletions src/petals/server/throughput.py
@@ -96,7 +96,7 @@ def get_server_throughput(
throughput = throughput_info["forward_rps"] / average_blocks_used
throughput = min(throughput, throughput_info.get("network_rps", math.inf))
throughput_info["throughput"] = throughput
logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")

return throughput_info

@@ -109,13 +109,10 @@ def measure_throughput_info(
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
) -> Dict[str, float]:
"""Measure network and compute throughput in forward pass tokens per second"""

logger.info(
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
)

throughput_info = {
return {
"inference_rps": measure_compute_rps(
config,
device,
@@ -136,37 +133,39 @@ def measure_throughput_info(
n_steps=10,
inference=False,
),
"network_rps": measure_network_rps(config),
}
try:
throughput_info["network_rps"] = measure_network_rps(config)
except Exception as e:
logger.info(f"Network throughput is not available: {e}")
return throughput_info


def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
process.start()

if not pipe_recv.poll(timeout):
process.terminate()
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
network_info = pipe_recv.recv()
if "exception" in network_info:
raise RuntimeError(f"speedtest failed: {network_info['exception']}")

def measure_network_rps(
config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 25e6
) -> Optional[float]:
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise RuntimeError("speedtest has returned network_rps == 0")

logger.info(
f"Network throughput: {network_rps:.1f} RPS "
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
)
return network_rps
try:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
process.start()

if not pipe_recv.poll(timeout):
process.terminate()
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
network_info = pipe_recv.recv()
if "exception" in network_info:
raise RuntimeError(f"speedtest failed: {network_info['exception']}")

network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise RuntimeError("speedtest has returned network_rps == 0")

logger.info(
f"Network throughput: {network_rps:.1f} tokens/sec "
f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
)
return network_rps
except RuntimeError as e:
logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s")
return default_speed / bits_per_request


def _measure_bits_per_second(pipe_send: mp.Pipe):
@@ -215,7 +214,7 @@ def measure_compute_rps(
devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())

logger.info(
f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block "
f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block "
f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
)
return device_rps
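For reference, here is how the new 25 Mbit/s fallback in `measure_network_rps` translates into a default `network_rps` when the speedtest fails. This is a quick sketch of the arithmetic only; the hidden size is an assumption for illustration, not a value taken from this PR:

```python
# Sketch of the fallback computation (hidden_size is an illustrative assumption).
hidden_size = 8192                       # e.g., a 65B-scale model
default_speed = 25e6                     # assumed 25 Mbit/s default link speed
bits_per_request = hidden_size * 16      # clients send 16-bit tensors for forward/backward
default_network_rps = default_speed / bits_per_request
print(f"{default_network_rps:.1f} tokens/sec")  # ~190.7 tokens/sec
```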
8 changes: 6 additions & 2 deletions src/petals/utils/auto_config.py
@@ -31,8 +31,12 @@ class _AutoDistributedBase:

@classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
kwargs["token"] = True
if (
always_needs_auth(model_name_or_path)
and kwargs.get("token") is None
and kwargs.get("use_auth_token") is None
):
kwargs["use_auth_token"] = True

config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
if config.model_type not in _CLASS_MAPPING:
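As a usage note, with the change above a gated model can still be loaded by passing the authentication keyword explicitly. A minimal sketch, where the import path and model name are assumptions for illustration and `use_auth_token` is the older `transformers` spelling this PR standardizes on:

```python
# Hypothetical usage sketch; the import path and model name are assumptions, not part of this PR.
from petals import AutoDistributedConfig

# Pass a token string, or True to reuse credentials cached by `huggingface-cli login`.
config = AutoDistributedConfig.from_pretrained("enoch/llama-65b-hf", use_auth_token=True)
```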
25 changes: 0 additions & 25 deletions tests/scripts/remove_old_models.py

This file was deleted.
