当前位置: 首页>>代码示例>>Python>>正文


Python utils.get_dist_info方法代码示例

本文整理汇总了Python中mmcv.runner.utils.get_dist_info方法的典型用法代码示例。如果您正苦于以下问题:Python utils.get_dist_info方法的具体用法?Python utils.get_dist_info怎么用?Python utils.get_dist_info使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在mmcv.runner.utils的用法示例。


在下文中一共展示了utils.get_dist_info方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: init_logger

# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def init_logger(log_dir=None, level=logging.INFO):
    """Init the logger.

    Args:
        log_dir(str, optional): Log file directory. If not specified, no
            log file will be used.
        level (int or str): See the built-in python logging module.

    Returns:
        :obj:`~logging.Logger`: Python logger.
    """
    rank, _ = get_dist_info()
    logging.basicConfig(
        format='%(asctime)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    if log_dir and rank == 0:
        filename = '{}.log'.format(time.strftime('%Y%m%d_%H%M%S', time.localtime()))
        log_file = osp.join(log_dir, filename)
        _add_file_handler(logger, log_file, level=level)
    return logger 
开发者ID:JaminFong,项目名称:FNA,代码行数:22,代码来源:utils.py

示例2: get_root_logger

# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def get_root_logger(log_dir=None, log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(message)s',
            level=log_level,
            datefmt='%m/%d %I:%M:%S %p')
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')

    if log_dir and rank == 0:
        filename = '{}.log'.format(time.strftime('%Y%m%d_%H%M%S', time.localtime()))
        log_file = osp.join(log_dir, filename)
        _add_file_handler(logger, log_file, level=log_level)
    return logger 
开发者ID:JaminFong,项目名称:FNA,代码行数:18,代码来源:utils.py

示例3: __init__

# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        self.num_samples = 0
        for i, j in enumerate(self.group_sizes):
            self.num_samples += int(
                math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas 
开发者ID:xvjiarui,项目名称:GCNet,代码行数:28,代码来源:sampler.py

示例4: load_checkpoint

# 需要导入模块: from mmcv.runner import utils [as 别名]
# 或者: from mmcv.runner.utils import get_dist_info [as 别名]
def load_checkpoint(filename,
                    model=None,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    if logger is None:
        logger = logging.getLogger()
    # load checkpoint from modelzoo or file or url
    logger.info('Start loading the model from ' + filename)
    if filename.startswith(('http://', 'https://')):
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0]==0:
            if osp.isfile(filename):
                os.system('rm '+filename)
            os.system('wget -N -q -P ../ ' + url)
        dist.barrier()
    elif filename.startswith(('hdfs://',)):
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0]==0:
            if osp.isfile(filename):
                os.system('rm '+filename)
            os.system('hdfs dfs -get ' + url + ' ../')
        dist.barrier()
    else:
        if not osp.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
    checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict) or isinstance(checkpoint, dict):
        state_dict = checkpoint
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    if model is not None:
        if hasattr(model, 'module'):
            model.module.load_state_dict(state_dict, strict=strict)
        else:
            model.load_state_dict(state_dict, strict=strict)
        logger.info('Loading the model finished!')
    return state_dict 
开发者ID:JaminFong,项目名称:FNA,代码行数:61,代码来源:utils.py


注:本文中的mmcv.runner.utils.get_dist_info方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。