Commit 2b6b220

add missing update for utils
liuqiuhui2015 committed Jun 12, 2020
1 parent ce580f3 commit 2b6b220
Showing 2 changed files with 57 additions and 4 deletions.
45 changes: 41 additions & 4 deletions utils/base.py
@@ -5,14 +5,27 @@

from threading import Thread

from functools import wraps

from random import sample
from random import seed as rpyseed

from math import ceil

import logging

from utils.h5serial import h5save, h5load

-mask_tensor_type = torch.uint8 if torch.__version__ < "1.2.0" else torch.bool
secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64}

+# handling torch.bool
+if torch.__version__ < "1.2.0":
+	mask_tensor_type = torch.uint8
+	nccl_type_map = None
+else:
+	mask_tensor_type = torch.bool
+	secure_type_map[mask_tensor_type] = torch.int64
+	nccl_type_map = {torch.bool:torch.uint8}

def pad_tensors(tensor_list, dim=-1):

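For orientation, a minimal sketch of how these module-level flags could be consumed elsewhere (illustrative only; the tensor shapes and the .half() example are assumptions, not part of this commit):

import torch
from utils.base import mask_tensor_type, secure_type_map

# build a mask whose dtype matches the installed PyTorch
# (torch.uint8 before 1.2.0, torch.bool from 1.2.0 onwards)
pad_mask = torch.zeros(4, 8, dtype=mask_tensor_type)

# widen a tensor before a numerically sensitive operation, falling back
# to its own dtype when no mapping is registered for it
x = torch.randn(4, 8).half()
x_secure = x.to(secure_type_map.get(x.dtype, x.dtype))
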
@@ -256,14 +269,15 @@ def ModuleList2Dict(modin):

def add_module(m, strin, m_add):

if strin.find(".") < 0:
_name_list = strin.split(".")
if len(_name_list) == 1:
m.add_module(strin, m_add)
else:
_m, _name_list = m, strin.split(".")
_m = m
# update _modules with pytorch: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module
for _tmp in _name_list[:-1]:
_m = _m._modules[_tmp]
_m._modules[_name_list[-1]] = m_add
_m.add_module(_name_list[-1], m_add)

	return m

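For reference, a quick illustration of what the revised helper does (the toy modules below are made up for the example): it walks the dotted name through _modules and registers the replacement on the parent via add_module, instead of assigning into _modules directly.

import torch.nn as nn
from utils.base import add_module

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
add_module(model, "0", nn.Linear(8, 4))	# single name: registered on the root module
print(model[0])	# Linear(in_features=8, out_features=4, bias=True)

wrapped = nn.ModuleDict({"encoder": nn.Sequential(nn.Linear(8, 8), nn.ReLU())})
add_module(wrapped, "encoder.1", nn.Tanh())	# dotted path: replaces the ReLU inside "encoder"
print(wrapped["encoder"][1])	# Tanh()
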
@@ -324,3 +338,26 @@ def report_parameters(modin):
		rs += _para.numel()

	return rs

def float2odd(fin):

	_rs = ceil(fin)
	if _rs % 2 == 1:
		_rs -= 1

	return _rs

def wrap_float2odd(func):
	@wraps(func)
	def wrapper(*args, **kwargs):
		return float2odd(func(*args, **kwargs))
	return wrapper

def iternext(iterin):

	try:
		rs = next(iterin)
	except:
		rs = None

	return rs
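
A side note on the three new helpers: despite its name, float2odd returns an even value (it takes the ceiling and then steps down by one when that ceiling is odd), wrap_float2odd applies the same rounding to another function's return value, and iternext is a next() that yields None instead of raising once the iterator is exhausted. A small illustrative snippet (the half_of function is hypothetical):

from utils.base import float2odd, wrap_float2odd, iternext

print(float2odd(5.3), float2odd(6.3))	# 6 6: ceilings are 6 and 7, and 7 is stepped down to 6

@wrap_float2odd
def half_of(x):	# hypothetical example, not part of the repository
	return x / 2.0

print(half_of(13))	# 6.5 -> ceil 7 -> 6

it = iter([1, 2])
print(iternext(it), iternext(it), iternext(it))	# 1 2 None
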
16 changes: 16 additions & 0 deletions utils/comm.py
@@ -0,0 +1,16 @@
#encoding: utf-8

import torch.cuda.comm as comm
from utils.base import nccl_type_map

def secure_broadcast_coalesced(tensors, devices, buffer_size=10485760):

	if nccl_type_map is None:

		return comm.broadcast_coalesced(tensors, devices, buffer_size=buffer_size)
	else:
		src_type = [para.dtype for para in tensors]
		map_type = [nccl_type_map[para.dtype] if para.dtype in nccl_type_map else None for para in tensors]
		rs = comm.broadcast_coalesced([para if typ is None else para.to(typ) for para, typ in zip(tensors, map_type)], devices, buffer_size=buffer_size)

		return list(zip(*[para if mtyp is None else [pu.to(styp) for pu in para] for para, mtyp, styp in zip(list(zip(*rs)), map_type, src_type)]))

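A rough usage sketch of the new wrapper, assuming a machine with at least two visible CUDA devices (the tensors are placeholders): torch.bool tensors are sent as torch.uint8, presumably because the underlying NCCL broadcast did not handle bool buffers on PyTorch >= 1.2.0, and the original dtypes are restored on every target device.

import torch
from utils.comm import secure_broadcast_coalesced

devices = [0, 1]	# assumes at least two CUDA devices
weights = torch.randn(4, 4, device="cuda:0")
mask = torch.zeros(4, 4, dtype=torch.bool, device="cuda:0")

# bool tensors travel as uint8 and come back as bool on each device
copies = secure_broadcast_coalesced([weights, mask], devices)
for dev_weights, dev_mask in copies:
	print(dev_weights.device, dev_mask.dtype)	# ... torch.bool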