Skip to content

Commit

Permalink
init upload
Browse files Browse the repository at this point in the history
  • Loading branch information
usr922 authored Sep 16, 2022
1 parent 6160e18 commit 37701c2
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mmseg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .collect_env import collect_env
from .logger import get_root_logger

__all__ = ['get_root_logger', 'collect_env']
40 changes: 40 additions & 0 deletions mmseg/utils/collect_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0
# Modifications: Add code archive generation

import os
import tarfile

from mmcv.utils import collect_env as collect_base_env
from mmcv.utils import get_git_hash

import mmseg


def collect_env():
"""Collect the information of the running environments."""
env_info = collect_base_env()
env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'

return env_info


def is_source_file(x):
if x.isdir() or x.name.endswith(('.py', '.sh', '.yml', '.json', '.txt')) \
and '.mim' not in x.name and 'jobs/' not in x.name:
# print(x.name)
return x
else:
return None


def gen_code_archive(out_dir, file='code.tar.gz'):
archive = os.path.join(out_dir, file)
os.makedirs(os.path.dirname(archive), exist_ok=True)
with tarfile.open(archive, mode='w:gz') as tar:
tar.add('.', filter=is_source_file)
return archive


if __name__ == '__main__':
for name, val in collect_env().items():
print('{}: {}'.format(name, val))
29 changes: 29 additions & 0 deletions mmseg/utils/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Obtained from: https://github.com/open-mmlab/mmsegmentation/tree/v0.16.0

import logging

from mmcv.utils import get_logger


def get_root_logger(log_file=None, log_level=logging.INFO):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added. The name of the root logger is the top-level package name,
e.g., "mmseg".
Args:
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
Returns:
logging.Logger: The root logger.
"""

logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)

return logger
39 changes: 39 additions & 0 deletions mmseg/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import contextlib

import numpy as np
import torch
import torch.nn.functional as F


@contextlib.contextmanager
def np_local_seed(seed):
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)


def downscale_label_ratio(gt,
scale_factor,
min_ratio,
n_classes,
ignore_index=255):
assert scale_factor > 1
bs, orig_c, orig_h, orig_w = gt.shape
assert orig_c == 1
trg_h, trg_w = orig_h // scale_factor, orig_w // scale_factor
ignore_substitute = n_classes

out = gt.clone() # otw. next line would modify original gt
out[out == ignore_index] = ignore_substitute
out = F.one_hot(
out.squeeze(1), num_classes=n_classes + 1).permute(0, 3, 1, 2)
assert list(out.shape) == [bs, n_classes + 1, orig_h, orig_w], out.shape
out = F.avg_pool2d(out.float(), kernel_size=scale_factor)
gt_ratio, out = torch.max(out, dim=1, keepdim=True)
out[out == ignore_substitute] = ignore_index
out[gt_ratio < min_ratio] = ignore_index
assert list(out.shape) == [bs, 1, trg_h, trg_w], out.shape
return out
18 changes: 18 additions & 0 deletions mmseg/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Open-MMLab. All rights reserved.

__version__ = '0.16.0'


def parse_version_info(version_str):
version_info = []
for x in version_str.split('.'):
if x.isdigit():
version_info.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
version_info.append(int(patch_version[0]))
version_info.append(f'rc{patch_version[1]}')
return tuple(version_info)


version_info = parse_version_info(__version__)

0 comments on commit 37701c2

Please sign in to comment.