Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Llama, rebalancing, throughput eval, and all CLI scripts #452

Merged
merged 32 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
69abacc
Show argparse defaults, fix docstring
borzunov Aug 8, 2023
ca2850e
Test petals.cli.run_dht
borzunov Aug 8, 2023
816401e
Increase mean_block_selection_delay
borzunov Aug 8, 2023
7330653
Test rebalancing
borzunov Aug 8, 2023
a00e79d
Add help to benchmark argparse
borzunov Aug 8, 2023
5b3d4c4
Use less RAM
borzunov Aug 8, 2023
2b765b9
Don't set default model in benchmarks
borzunov Aug 8, 2023
fae58d9
Fix sleep time
borzunov Aug 8, 2023
856f53f
Test --throughput eval
borzunov Aug 8, 2023
05dc383
Fix flapping test
borzunov Aug 8, 2023
18e5b00
Use AutoDistributed{Config,Model} in tests
borzunov Aug 8, 2023
168e478
Add Maykeye/TinyLLama-v0 to tests
borzunov Aug 8, 2023
5760b15
Test using includes only
borzunov Aug 8, 2023
015238a
Adjust --num_blocks and --block_indices for 8-layer TinyLlama-v0
borzunov Aug 8, 2023
17cae64
Refactor matrix
borzunov Aug 8, 2023
b7b7464
Fix commands
borzunov Aug 8, 2023
c907990
Skip TP tests for llama
borzunov Aug 8, 2023
0040539
Fix test_greedy_generation() for llama
borzunov Aug 8, 2023
a5a95c4
Fix commands
borzunov Aug 8, 2023
c3e7638
Fix test_server_info()
borzunov Aug 8, 2023
b622a14
Fix server layout
borzunov Aug 8, 2023
8a379aa
Try reducing RAM usage
borzunov Aug 8, 2023
ecd7d3f
Check if benchmarks work
borzunov Aug 8, 2023
6ffbc28
Watch free RAM (common issue in CI)
borzunov Aug 8, 2023
033a3ca
Reduce RAM further
borzunov Aug 8, 2023
f06cebd
Tune constants to save RAM
borzunov Aug 8, 2023
47d2d53
Speed benchmark tests
borzunov Aug 8, 2023
d8e08e6
Fix flapping test
borzunov Aug 8, 2023
315c5c6
Try --no_relay
borzunov Aug 8, 2023
5cbb33b
Increase swap space
borzunov Aug 8, 2023
54cd213
Fix flapping test
borzunov Aug 8, 2023
1e34dfd
Fix flapping test
borzunov Aug 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 68 additions & 25 deletions .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,20 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
include:
- { model: 'bigscience/bloom-560m', python-version: '3.8' }
- { model: 'bigscience/bloom-560m', python-version: '3.9' }
- { model: 'bigscience/bloom-560m', python-version: '3.10' }
- { model: 'bigscience/bloom-560m', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
fail-fast: false
timeout-minutes: 15
steps:
- name: Increase swap space
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Checkout
uses: actions/checkout@v3
- name: Set up Python
Expand All @@ -31,44 +41,77 @@ jobs:
pip install .[dev]
- name: Test
run: |
export MODEL_NAME=bigscience/bloom-560m
export REF_NAME=bigscience/bloom-560m
export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft

python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--adapters $ADAPTER_NAME &> server1.log &
SERVER1_PID=$!
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

# [Step 1] Watch free RAM (lack of RAM is a common issue in CI)

bash -c 'while true; do free -h && sleep 30s; done' &
RAM_WATCH_PID=$!

sleep 5 # wait for the first server to initialize DHT
# [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)

python -m petals.cli.run_dht \
--identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
BOOTSTRAP_PID=$!

export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
# ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs

python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
SERVER2_PID=$!
sleep 5 # wait for DHT init

python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
--mean_balance_check_period 10 \
--initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
SERVER1_PID=$!
# ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there

sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
sleep 10 # wait for the 1st server to choose blocks

python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
--identity_path tests/server2.id \
--initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
SERVER2_PID=$!

python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
--attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
SERVER3_PID=$!
# ^-- chunking test

python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
--initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet)

tail -n 100 -f server*.log &
sleep 5 # wait for the log files to appear

tail -n 100 -f bootstrap.log server*.log &
LOGGER_PID=$!
sleep 30 # wait for servers to download layers

kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
sleep 30 # wait for servers to eval throughput, download layers, and rebalance
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init

# [Step 3] Run PyTest

pytest tests --durations=0 --durations-min=1.0 -v

kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests
# [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)

python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3
python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --n_steps 1
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
--seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm

# [Step 5] Clean up

kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests

kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
echo "Done!"
18 changes: 9 additions & 9 deletions benchmarks/benchmark_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--n_steps", type=int, default=100)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()

if args.n_processes == "n_gpus":
Expand Down
14 changes: 7 additions & 7 deletions benchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=2048)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()

if args.n_processes == "n_gpus":
Expand Down
24 changes: 12 additions & 12 deletions benchmarks/benchmark_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="bigscience/bloom")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--task", type=str, default="cls")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
parser.add_argument("--n_processes", type=str, default=1)
parser.add_argument("--seq_len", type=int, default=128)
parser.add_argument("--pre_seq_len", type=int, default=16)
parser.add_argument("--n_steps", type=int, default=10)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--warmup_steps", type=int, default=1)
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--model", type=str, required=True, help="Model")
parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
parser.add_argument("--task", type=str, default="cls", help="Training task type")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
args = parser.parse_args()

assert args.task in ["cls", "causal_lm"]
Expand Down
8 changes: 5 additions & 3 deletions src/petals/cli/run_dht.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
This may be eventually merged to the hivemind upstream.
"""

import argparse
import time
from argparse import ArgumentParser
from secrets import token_hex

from hivemind.dht import DHT, DHTNode
Expand All @@ -35,7 +35,7 @@ async def report_status(dht: DHT, node: DHTNode):


def main():
parser = ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--initial_peers",
nargs="*",
Expand Down Expand Up @@ -73,7 +73,9 @@ def main():
help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
)
parser.add_argument(
"--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
"--use_auto_relay",
action="store_true",
help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
)
parser.add_argument(
"--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
Expand Down
2 changes: 1 addition & 1 deletion src/petals/cli/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main():
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")

parser.add_argument("--adapters", nargs='+', default=(),
parser.add_argument("--adapters", nargs='*', default=(),
help="List of pre-loaded LoRA adapters that can be used for inference or training")

# fmt:on
Expand Down
6 changes: 4 additions & 2 deletions src/petals/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
sender_threads: int = 1,
balance_quality: float = 0.75,
mean_balance_check_period: float = 120,
mean_block_selection_delay: float = 2.5,
mean_block_selection_delay: float = 5,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous delay was not enough - servers were often choosing the same blocks since they didn't have time to write DHT "JOINING" messages.

token: Optional[Union[str, bool]] = None,
quant_type: Optional[QuantType] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
Expand Down Expand Up @@ -201,6 +201,8 @@ def __init__(
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()
if num_blocks is not None:
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
if block_indices is not None:
try:
first_block_index, last_block_index = block_indices.split(":")
Expand Down Expand Up @@ -295,7 +297,7 @@ def _choose_num_blocks(self) -> int:

num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
logger.info(
f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
)
return num_blocks
Expand Down
File renamed without changes.
Binary file added tests/server2.id
Binary file not shown.
3 changes: 3 additions & 0 deletions tests/test_aux_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def test_bnb_not_imported_when_unnecessary():
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
if tensor_parallel and config.model_type != "bloom":
pytest.skip("Tensor parallelism is implemented only for BLOOM for now")

tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
compute_rps = measure_compute_rps(
config,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_block_exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import pytest
import torch

from petals import DistributedBloomConfig, RemoteSequential
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *


@pytest.mark.forked
def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)

for block_index in random.sample(range(config.num_hidden_layers), 3):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_chained_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
import pytest
import torch

from petals import DistributedBloomConfig
from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *


@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
assert isinstance(remote_blocks, RemoteSequential)

Expand Down Expand Up @@ -43,7 +43,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq

@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_blocks = RemoteSequential(config, start_block=3, end_block=5)

inputs = torch.randn(1, 8, config.hidden_size)
Expand Down
Loading
Loading