forked from open-mmlab/OpenPCDet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support recording batch time during training (open-mmlab#697)
- Loading branch information
Showing
3 changed files
with
228 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
""" | ||
This file contains primitives for multi-gpu communication. | ||
This is useful when doing distributed training. | ||
deeply borrow from maskrcnn-benchmark and ST3D | ||
""" | ||
|
||
import pickle | ||
import time | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
||
|
||
def get_world_size(): | ||
if not dist.is_available(): | ||
return 1 | ||
if not dist.is_initialized(): | ||
return 1 | ||
return dist.get_world_size() | ||
|
||
|
||
def get_rank(): | ||
if not dist.is_available(): | ||
return 0 | ||
if not dist.is_initialized(): | ||
return 0 | ||
return dist.get_rank() | ||
|
||
|
||
def is_main_process(): | ||
return get_rank() == 0 | ||
|
||
|
||
def synchronize(): | ||
""" | ||
Helper function to synchronize (barrier) among all processes when | ||
using distributed training | ||
""" | ||
if not dist.is_available(): | ||
return | ||
if not dist.is_initialized(): | ||
return | ||
world_size = dist.get_world_size() | ||
if world_size == 1: | ||
return | ||
dist.barrier() | ||
|
||
|
||
def all_gather(data): | ||
""" | ||
Run all_gather on arbitrary picklable data (not necessarily tensors) | ||
Args: | ||
data: any picklable object | ||
Returns: | ||
list[data]: list of data gathered from each rank | ||
""" | ||
world_size = get_world_size() | ||
if world_size == 1: | ||
return [data] | ||
|
||
# serialized to a Tensor | ||
origin_size = None | ||
if not isinstance(data, torch.Tensor): | ||
buffer = pickle.dumps(data) | ||
storage = torch.ByteStorage.from_buffer(buffer) | ||
tensor = torch.ByteTensor(storage).to("cuda") | ||
else: | ||
origin_size = data.size() | ||
tensor = data.reshape(-1) | ||
|
||
tensor_type = tensor.dtype | ||
|
||
# obtain Tensor size of each rank | ||
local_size = torch.LongTensor([tensor.numel()]).to("cuda") | ||
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] | ||
dist.all_gather(size_list, local_size) | ||
size_list = [int(size.item()) for size in size_list] | ||
max_size = max(size_list) | ||
|
||
# receiving Tensor from all ranks | ||
# we pad the tensor because torch all_gather does not support | ||
# gathering tensors of different shapes | ||
tensor_list = [] | ||
for _ in size_list: | ||
tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type)) | ||
if local_size != max_size: | ||
padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type) | ||
tensor = torch.cat((tensor, padding), dim=0) | ||
dist.all_gather(tensor_list, tensor) | ||
|
||
data_list = [] | ||
for size, tensor in zip(size_list, tensor_list): | ||
if origin_size is None: | ||
buffer = tensor.cpu().numpy().tobytes()[:size] | ||
data_list.append(pickle.loads(buffer)) | ||
else: | ||
buffer = tensor[:size] | ||
data_list.append(buffer) | ||
|
||
if origin_size is not None: | ||
new_shape = [-1] + list(origin_size[1:]) | ||
resized_list = [] | ||
for data in data_list: | ||
# suppose the difference of tensor size exist in first dimension | ||
data = data.reshape(new_shape) | ||
resized_list.append(data) | ||
|
||
return resized_list | ||
else: | ||
return data_list | ||
|
||
|
||
def reduce_dict(input_dict, average=True): | ||
""" | ||
Args: | ||
input_dict (dict): all the values will be reduced | ||
average (bool): whether to do average or sum | ||
Reduce the values in the dictionary from all processes so that process with rank | ||
0 has the averaged results. Returns a dict with the same fields as | ||
input_dict, after reduction. | ||
""" | ||
world_size = get_world_size() | ||
if world_size < 2: | ||
return input_dict | ||
with torch.no_grad(): | ||
names = [] | ||
values = [] | ||
# sort the keys so that they are consistent across processes | ||
for k in sorted(input_dict.keys()): | ||
names.append(k) | ||
values.append(input_dict[k]) | ||
values = torch.stack(values, dim=0) | ||
dist.reduce(values, dst=0) | ||
if dist.get_rank() == 0 and average: | ||
# only main process gets accumulated, so only divide by | ||
# world_size in this case | ||
values /= world_size | ||
reduced_dict = {k: v for k, v in zip(names, values)} | ||
return reduced_dict | ||
|
||
|
||
def average_reduce_value(data): | ||
data_list = all_gather(data) | ||
return sum(data_list) / len(data_list) | ||
|
||
|
||
def all_reduce(data, op="sum", average=False): | ||
|
||
def op_map(op): | ||
op_dict = { | ||
"SUM": dist.ReduceOp.SUM, | ||
"MAX": dist.ReduceOp.MAX, | ||
"MIN": dist.ReduceOp.MIN, | ||
"PRODUCT": dist.ReduceOp.PRODUCT, | ||
} | ||
return op_dict[op] | ||
|
||
world_size = get_world_size() | ||
if world_size > 1: | ||
reduced_data = data.clone() | ||
dist.all_reduce(reduced_data, op=op_map(op.upper())) | ||
if average: | ||
assert op.upper() == 'SUM' | ||
return reduced_data / world_size | ||
else: | ||
return reduced_data | ||
return data | ||
|
||
|
||
@torch.no_grad() | ||
def concat_all_gather(tensor): | ||
""" | ||
Performs all_gather operation on the provided tensors. | ||
*** Warning ***: torch.distributed.all_gather has no gradient. | ||
""" | ||
tensors_gather = [torch.ones_like(tensor) | ||
for _ in range(torch.distributed.get_world_size())] | ||
torch.distributed.all_gather(tensors_gather, tensor, async_op=False) | ||
|
||
output = torch.cat(tensors_gather, dim=0) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters