本文整理汇总了Python中mmcv.runner.utils.get_dist_info方法的典型用法代码示例。如果您正苦于以下问题：Python utils.get_dist_info方法的具体用法？Python utils.get_dist_info怎么用？Python utils.get_dist_info使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块mmcv.runner.utils的用法示例。
在下文中一共展示了utils.get_dist_info方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: init_logger
# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def init_logger(log_dir=None, level=logging.INFO):
    """Set up console logging and return this module's logger.

    Console output is configured through :func:`logging.basicConfig`. When
    ``log_dir`` is given and ``get_dist_info`` reports rank 0, a
    timestamped path ``<log_dir>/<YYYYmmdd_HHMMSS>.log`` is additionally
    handed to ``_add_file_handler`` together with ``level``.

    Args:
        log_dir (str, optional): Directory for the log file. If falsy, no
            log file is used.
        level (int or str): Logging level, as accepted by the built-in
            :mod:`logging` module.

    Returns:
        :obj:`~logging.Logger`: The logger named after this module.
    """
    rank, _world_size = get_dist_info()
    logging.basicConfig(format='%(asctime)s - %(message)s', level=level)
    module_logger = logging.getLogger(__name__)

    # Only rank 0 with a target directory gets a log file; everyone else
    # keeps console-only logging.
    if not log_dir or rank != 0:
        return module_logger

    stamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_path = osp.join(log_dir, '{}.log'.format(stamp))
    _add_file_handler(module_logger, log_path, level=level)
    return module_logger
示例2: get_root_logger
# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def get_root_logger(log_dir=None, log_level=logging.INFO):
    """Return the root logger, configured for (possibly distributed) runs.

    If the root logger has no handlers yet, :func:`logging.basicConfig`
    installs console output at ``log_level``. Processes whose rank
    (from ``get_dist_info``) is non-zero have the root level raised to
    ``ERROR``. Rank 0, when ``log_dir`` is given, hands a timestamped
    ``<log_dir>/<YYYYmmdd_HHMMSS>.log`` path to ``_add_file_handler``.

    Args:
        log_dir (str, optional): Directory for the log file; falsy disables
            file logging.
        log_level (int or str): Level passed to ``basicConfig`` and to
            ``_add_file_handler``.

    Returns:
        :obj:`~logging.Logger`: The root logger.
    """
    root = logging.getLogger()
    if not root.hasHandlers():
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(message)s',
            datefmt='%m/%d %I:%M:%S %p',
        )

    rank, _world_size = get_dist_info()
    if rank != 0:
        # Non-master processes only report errors.
        root.setLevel('ERROR')
    elif log_dir:
        stamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        log_path = osp.join(log_dir, '{}.log'.format(stamp))
        _add_file_handler(root, log_path, level=log_level)
    return root
示例3: __init__
# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def __init__(self,
             dataset,
             samples_per_gpu=1,
             num_replicas=None,
             rank=None):
    """Set up a distributed sampler that draws samples group by group.

    Args:
        dataset: Dataset exposing a ``flag`` attribute. ``flag`` is passed
            to ``np.bincount``, so it must be a 1-D array of non-negative
            integers (one group id per sample).
        samples_per_gpu (int): Number of samples per GPU batch.
        num_replicas (int, optional): Number of processes. Defaults to the
            world size reported by ``get_dist_info``.
        rank (int, optional): Rank of this process. Defaults to the rank
            reported by ``get_dist_info``.

    Sets ``group_sizes`` (sample count per group id), ``num_samples``
    (sum over groups of each size rounded up to a multiple of
    ``samples_per_gpu * num_replicas``, divided by ``num_replicas``) and
    ``total_size`` (``num_samples * num_replicas``).
    """
    default_rank, default_num_replicas = get_dist_info()
    if num_replicas is None:
        num_replicas = default_num_replicas
    if rank is None:
        rank = default_rank
    self.dataset = dataset
    self.samples_per_gpu = samples_per_gpu
    self.num_replicas = num_replicas
    self.rank = rank
    self.epoch = 0

    assert hasattr(self.dataset, 'flag'), \
        'dataset must define a `flag` attribute'
    self.flag = self.dataset.flag
    self.group_sizes = np.bincount(self.flag)

    # Pad every group up to a whole number of per-GPU batches on every
    # replica. Integer ceil-division (-(-a // b)) keeps this exact instead
    # of going through floating-point division.
    samples_per_step = self.samples_per_gpu * self.num_replicas
    self.num_samples = sum(
        -(-int(size) // samples_per_step) * self.samples_per_gpu
        for size in self.group_sizes)
    self.total_size = self.num_samples * self.num_replicas
示例4: load_checkpoint
# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def _fetch_checkpoint_on_rank0(url, filename, fetch_cmd, logger):
    """Download ``url`` to ``filename`` on rank 0, then synchronize ranks.

    Rank 0 deletes any existing ``filename`` and runs ``fetch_cmd`` as an
    argument list (no shell, so characters in ``url`` are never interpreted
    by a shell). Failures are logged instead of raised so that rank 0 still
    reaches the barrier; the caller checks afterwards that ``filename``
    exists. The barrier is only used when more than one process runs.

    Args:
        url (str): Remote checkpoint location (used in log messages).
        filename (str): Local path the fetch command is expected to write.
        fetch_cmd (list[str]): argv of the download command.
        logger (logging.Logger): Logger for fetch errors.
    """
    import subprocess

    rank, world_size = get_dist_info()
    if rank == 0:
        # Start from a clean state so a stale copy is never loaded.
        if osp.isfile(filename):
            os.remove(filename)
        try:
            returncode = subprocess.run(fetch_cmd).returncode
        except OSError as err:  # e.g. the download tool is not installed
            logger.error('Could not run %s to fetch %s: %s',
                         fetch_cmd[0], url, err)
        else:
            if returncode != 0:
                logger.error('%s exited with status %d while fetching %s',
                             fetch_cmd[0], returncode, url)
    if world_size > 1:
        # All ranks wait here until rank 0 has finished fetching.
        dist.barrier()


def load_checkpoint(filename,
                    model=None,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load a checkpoint from a local file, an HTTP(S) URL or an HDFS URI.

    Remote checkpoints (``http://``, ``https://``, ``hdfs://``) are fetched
    to ``../<basename>`` by rank 0 (``wget`` for HTTP(S), ``hdfs dfs -get``
    for HDFS); when running distributed the other ranks wait at a barrier.
    If the first key of the loaded dict starts with ``module.``, the first
    ``len('module.')`` characters are stripped from every key.

    Args:
        filename (str): Local path, ``http(s)://`` URL or ``hdfs://`` URI.
        model (Module, optional): Model to load the weights into. If it has
            a ``module`` attribute, the weights go into ``model.module``.
        map_location: Same as :func:`torch.load`.
        strict (bool): Passed to ``load_state_dict``; when True the
            checkpoint keys must exactly match the model's keys.
        logger (:mod:`logging.Logger` or None): Logger for progress and
            error messages. Defaults to the root logger.

    Returns:
        dict or OrderedDict: The loaded (possibly prefix-stripped) state dict.

    Raises:
        IOError: If the checkpoint file does not exist (locally, or after a
            failed remote fetch).
        RuntimeError: If the loaded object is not a dict.
    """
    if logger is None:
        logger = logging.getLogger()
    logger.info('Start loading the model from %s', filename)

    if filename.startswith(('http://', 'https://', 'hdfs://')):
        url = filename
        filename = '../' + url.split('/')[-1]
        if url.startswith('hdfs://'):
            fetch_cmd = ['hdfs', 'dfs', '-get', url, '../']
        else:
            fetch_cmd = ['wget', '-N', '-q', '-P', '../', url]
        _fetch_checkpoint_on_rank0(url, filename, fetch_cmd, logger)

    if not osp.isfile(filename):
        raise IOError('{} is not a checkpoint file'.format(filename))
    checkpoint = torch.load(filename, map_location=map_location)

    # OrderedDict is a dict subclass, so one check covers both.
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    state_dict = checkpoint

    # Strip a leading 'module.' prefix; tolerate empty / non-str-keyed dicts.
    prefix = 'module.'
    first_key = next(iter(state_dict), None)
    if isinstance(first_key, str) and first_key.startswith(prefix):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

    if model is not None:
        target = model.module if hasattr(model, 'module') else model
        target.load_state_dict(state_dict, strict=strict)
    logger.info('Loading the model finished!')
    return state_dict