當前位置: 首頁>>代碼示例>>Python>>正文


Python utils.get_dist_info方法代碼示例

本文整理匯總了Python中mmcv.runner.utils.get_dist_info方法的典型用法代碼示例。如果您正苦於以下問題:Python utils.get_dist_info方法的具體用法?Python utils.get_dist_info怎麽用?Python utils.get_dist_info使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在mmcv.runner.utils的用法示例。


在下文中一共展示了utils.get_dist_info方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: init_logger

# 需要導入模塊: from mmcv.runner import utils [as 別名]
# 或者: from mmcv.runner.utils import get_dist_info [as 別名]
def init_logger(log_dir=None, level=logging.INFO):
    """Init the logger.

    Args:
        log_dir(str, optional): Log file directory. If not specified, no
            log file will be used.
        level (int or str): See the built-in python logging module.

    Returns:
        :obj:`~logging.Logger`: Python logger.
    """
    rank, _ = get_dist_info()
    logging.basicConfig(
        format='%(asctime)s - %(message)s', level=level)
    logger = logging.getLogger(__name__)
    if log_dir and rank == 0:
        filename = '{}.log'.format(time.strftime('%Y%m%d_%H%M%S', time.localtime()))
        log_file = osp.join(log_dir, filename)
        _add_file_handler(logger, log_file, level=level)
    return logger 
開發者ID:JaminFong,項目名稱:FNA,代碼行數:22,代碼來源:utils.py

示例2: get_root_logger

# 需要導入模塊: from mmcv.runner import utils [as 別名]
# 或者: from mmcv.runner.utils import get_dist_info [as 別名]
def get_root_logger(log_dir=None, log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(message)s',
            level=log_level,
            datefmt='%m/%d %I:%M:%S %p')
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')

    if log_dir and rank == 0:
        filename = '{}.log'.format(time.strftime('%Y%m%d_%H%M%S', time.localtime()))
        log_file = osp.join(log_dir, filename)
        _add_file_handler(logger, log_file, level=log_level)
    return logger 
開發者ID:JaminFong,項目名稱:FNA,代碼行數:18,代碼來源:utils.py

示例3: __init__

# 需要導入模塊: from mmcv.runner import utils [as 別名]
# 或者: from mmcv.runner.utils import get_dist_info [as 別名]
def __init__(self,
                 dataset,
                 samples_per_gpu=1,
                 num_replicas=None,
                 rank=None):
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        self.num_samples = 0
        for i, j in enumerate(self.group_sizes):
            self.num_samples += int(
                math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
                          self.num_replicas)) * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas 
開發者ID:xvjiarui,項目名稱:GCNet,代碼行數:28,代碼來源:sampler.py

示例4: load_checkpoint

# 需要導入模塊: from mmcv.runner import utils [as 別名]
# 或者: from mmcv.runner.utils import get_dist_info [as 別名]
def load_checkpoint(filename,
                    model=None,
                    map_location=None,
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    if logger is None:
        logger = logging.getLogger()
    # load checkpoint from modelzoo or file or url
    logger.info('Start loading the model from ' + filename)
    if filename.startswith(('http://', 'https://')):
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0]==0:
            if osp.isfile(filename):
                os.system('rm '+filename)
            os.system('wget -N -q -P ../ ' + url)
        dist.barrier()
    elif filename.startswith(('hdfs://',)):
        url = filename
        filename = '../' + url.split('/')[-1]
        if get_dist_info()[0]==0:
            if osp.isfile(filename):
                os.system('rm '+filename)
            os.system('hdfs dfs -get ' + url + ' ../')
        dist.barrier()
    else:
        if not osp.isfile(filename):
            raise IOError('{} is not a checkpoint file'.format(filename))
    checkpoint = torch.load(filename, map_location=map_location)
    # get state_dict from checkpoint
    if isinstance(checkpoint, OrderedDict) or isinstance(checkpoint, dict):
        state_dict = checkpoint
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(filename))
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # load state_dict
    if model is not None:
        if hasattr(model, 'module'):
            model.module.load_state_dict(state_dict, strict=strict)
        else:
            model.load_state_dict(state_dict, strict=strict)
        logger.info('Loading the model finished!')
    return state_dict 
開發者ID:JaminFong,項目名稱:FNA,代碼行數:61,代碼來源:utils.py


注:本文中的mmcv.runner.utils.get_dist_info方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。