Commit e1bfcec

bugfixed: solve the issue of blocking with dist_train.sh, automatically search tcp ports (open-mmlab#815)

* bugfixed: stuck when training with dist_train.sh, support tcp_port

* bugfixed: solve the issue of blocking with dist_train.sh, automatically search tcp ports
sshaoshuai authored Feb 21, 2022
1 parent a5cf2a5 commit e1bfcec
Showing 3 changed files with 31 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pcdet/utils/common_utils.py
@@ -161,8 +161,8 @@ def init_dist_slurm(tcp_port, local_rank, backend='nccl'):
 def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
-    os.environ['MASTER_PORT'] = str(tcp_port)
-    os.environ['MASTER_ADDR'] = 'localhost'
+    # os.environ['MASTER_PORT'] = str(tcp_port)
+    # os.environ['MASTER_ADDR'] = 'localhost'
     num_gpus = torch.cuda.device_count()
     torch.cuda.set_device(local_rank % num_gpus)
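
The Python-side change retires the hard-coded rendezvous settings: the launch scripts below now choose a free port themselves and hand it to the launcher via --rdzv_endpoint, so re-exporting MASTER_PORT/MASTER_ADDR here would duplicate (and could contradict) what the launcher already configures. For context, a rough shell sketch of what the removed lines amounted to (the port number is a made-up example, not a value from this commit):

    # Hand-set the rendezvous env vars before dist init, as
    # init_dist_pytorch previously did from its tcp_port argument;
    # 18888 is an arbitrary example port.
    export MASTER_ADDR=localhost
    export MASTER_PORT=18888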

12 changes: 11 additions & 1 deletion tools/scripts/dist_train.sh
@@ -4,5 +4,15 @@ set -x
 NGPUS=$1
 PY_ARGS=${@:2}
 
-python -m torch.distributed.launch --nproc_per_node=${NGPUS} train.py --launcher pytorch ${PY_ARGS}
+while true
+do
+    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
+    status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)"
+    if [ "${status}" != "0" ]; then
+        break;
+    fi
+done
+echo $PORT
+
+python -m torch.distributed.launch --nproc_per_node=${NGPUS} --rdzv_endpoint=localhost:${PORT} train.py --launcher pytorch ${PY_ARGS}
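
A note on the loop above: RANDOM expands to a 15-bit value (0–32767), so ((RANDOM<<15)|RANDOM) composes a 30-bit random number, and % 49152 + 10000 maps it into 10000–59151, safely above the privileged port range. nc -z only probes whether anything accepts a connection on the port: exit status 0 means it is occupied, non-zero means it looks free, so the loop keeps drawing ports until a probe fails. On systems without nc, a rough equivalent (a sketch, assuming bash with /dev/tcp support enabled) is:

    # Probe candidate ports with bash's /dev/tcp pseudo-device instead of nc;
    # the connect attempt in the subshell fails (non-zero exit) when nothing
    # is listening, i.e. when the port looks free.
    while true
    do
        PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
        if ! (exec 3<>"/dev/tcp/127.0.0.1/${PORT}") 2>/dev/null; then
            break
        fi
    done
    echo ${PORT}

Either probe is best-effort: another process can still claim the port between the check and the moment the launcher binds it.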

18 changes: 18 additions & 0 deletions tools/scripts/torch_train.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+set -x
+NGPUS=$1
+PY_ARGS=${@:2}
+
+while true
+do
+    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
+    status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)"
+    if [ "${status}" != "0" ]; then
+        break;
+    fi
+done
+echo $PORT
+
+torchrun --nproc_per_node=${NGPUS} --rdzv_endpoint=localhost:${PORT} train.py --launcher pytorch ${PY_ARGS}
+
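
torch_train.sh is the same port search wired to torchrun, which PyTorch ships as the successor to python -m torch.distributed.launch. A hypothetical invocation from the tools directory (the config path is illustrative, not part of this commit):

    # Train on 4 GPUs; everything after the GPU count is forwarded to
    # train.py through ${PY_ARGS}. The --cfg_file path is only an example.
    cd tools
    bash scripts/torch_train.sh 4 --cfg_file cfgs/kitti_models/pointpillar.yaml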
