

Python mxnet.lr_scheduler Code Examples

This article collects typical usage examples of Python's mxnet.lr_scheduler module. If you are wondering what mxnet.lr_scheduler does, how to use it, or where to find working examples, the curated snippets below should help. You can also explore the broader mxnet package for further usage examples.


Four code examples of mxnet.lr_scheduler are shown below, ordered by popularity.
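
Before the examples, here is a minimal sketch of how an mx.lr_scheduler object is wired into training: a scheduler is a callable that maps the optimizer's update count to a learning rate, and it is usually handed to the optimizer through the optimizer parameters. The step positions and rates below are illustrative values only, not taken from the examples.

import mxnet as mx

# cut the learning rate by 10x after 8000 and again after 16000 updates
scheduler = mx.lr_scheduler.MultiFactorScheduler(step=[8000, 16000], factor=0.1)

# a scheduler is simply a callable: update count in, learning rate out
lr_now = scheduler(9000)  # default base_lr 0.01 after one decay step -> 0.001

# during training it is normally passed through the optimizer settings
optimizer_params = {'learning_rate': 0.01, 'momentum': 0.9, 'wd': 0.0005,
                    'lr_scheduler': scheduler}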

Example 1: get_lr_scheduler

# Required imports: import mxnet [as mx]
# or: from mxnet import lr_scheduler [as alias]
import logging
import mxnet as mx
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute learning rate and refactor scheduler

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs to change learning rate
    lr_refactor_ratio : float
        lr *= ratio at certain steps
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
    if lr_refactor_ratio >= 1:
        # a ratio >= 1 would never decay the learning rate, so no scheduler is needed
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size
        # when resuming, retroactively apply the decay for steps already passed
        for s in iter_refactor:
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
        if lr != learning_rate:
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
        # convert the remaining epoch boundaries to iteration counts
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
        if not steps:
            return (lr, None)
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
        return (lr, lr_scheduler)
Developer: awslabs, Project: dynamic-training-with-apache-mxnet-on-aws, Lines: 43, Source: train_net.py
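
A quick, hypothetical invocation of this helper; the dataset size and schedule are made-up values, not taken from the original project:

import logging
logging.basicConfig(level=logging.INFO)

# decay the lr by 10x at epochs 80 and 160, training from scratch
lr, scheduler = get_lr_scheduler(learning_rate=0.004,
                                 lr_refactor_step='80,160',
                                 lr_refactor_ratio=0.1,
                                 num_example=16000,  # made-up dataset size
                                 batch_size=32,
                                 begin_epoch=0)
# epoch_size = 16000 // 32 = 500, so lr stays 0.004 and scheduler is a
# MultiFactorScheduler with step=[40000, 80000]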

Example 2: get_lr_scheduler

# Required imports: import mxnet [as mx]
# or: from mxnet import lr_scheduler [as alias]
import logging
# BurnInMultiFactorScheduler is defined elsewhere in this project's train_net.py
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute learning rate and refactor scheduler

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs to change learning rate
    lr_refactor_ratio : float
        lr *= ratio at certain steps
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
    if lr_refactor_ratio >= 1:
        # a ratio >= 1 would never decay the learning rate, so no scheduler is needed
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size
        # when resuming, retroactively apply the decay for steps already passed
        for s in iter_refactor:
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
        if lr != learning_rate:
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
        # convert the remaining epoch boundaries to iteration counts
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
        if not steps:
            return (lr, None)
        # ramp the lr up over the first 1000 iterations ('burn-in') before step decay
        lr_scheduler = BurnInMultiFactorScheduler(burn_in=1000, step=steps, factor=lr_refactor_ratio)
        return (lr, lr_scheduler)
Developer: zhreshold, Project: mxnet-yolo, Lines: 43, Source: train_net.py
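
The only difference from Example 1 is the scheduler class: BurnInMultiFactorScheduler is project-specific code defined in mxnet-yolo's train_net.py and not shown in this excerpt. Below is a minimal sketch of the burn-in idea, assuming a Darknet-style polynomial warm-up; the exponent and the class's exact behavior are assumptions, not the project's verified code:

import mxnet as mx

class BurnInMultiFactorScheduler(mx.lr_scheduler.MultiFactorScheduler):
    """Step-decay scheduler with an initial warm-up ('burn-in') ramp (sketch)."""

    def __init__(self, burn_in, step, factor=1):
        super(BurnInMultiFactorScheduler, self).__init__(step=step, factor=factor)
        self.burn_in = burn_in

    def __call__(self, num_update):
        lr = super(BurnInMultiFactorScheduler, self).__call__(num_update)
        if num_update < self.burn_in:
            # ramp the lr polynomially from ~0 up to the base lr during burn-in
            # (exponent 4 mirrors Darknet's default and is an assumption here)
            lr = lr * (float(num_update) / self.burn_in) ** 4
        return lr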

Example 3: get_optimizer_params

# Required imports: import mxnet [as mx]
# or: from mxnet import lr_scheduler [as alias]
def get_optimizer_params(optimizer=None, learning_rate=None, momentum=None,
                         weight_decay=None, lr_scheduler=None, ctx=None, logger=None):
    if optimizer.lower() == 'rmsprop':
        opt = 'rmsprop'
        logger.info('you chose RMSProp, decreasing lr by a factor of 10')
        optimizer_params = {'learning_rate': learning_rate / 10.0,
                            'wd': weight_decay,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            # rescale gradients by the number of devices in ctx
                            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    elif optimizer.lower() == 'sgd':
        opt = 'sgd'
        optimizer_params = {'learning_rate': learning_rate,
                            'momentum': momentum,
                            'wd': weight_decay,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    elif optimizer.lower() == 'adadelta':
        opt = 'adadelta'
        optimizer_params = {}
    elif optimizer.lower() == 'adam':
        opt = 'adam'
        optimizer_params = {'learning_rate': learning_rate,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    else:
        # fail fast instead of returning unbound variables for unknown optimizers
        raise ValueError('Unsupported optimizer: {}'.format(optimizer))
    return opt, optimizer_params
Developer: zhreshold, Project: mxnet-ssd, Lines: 30, Source: train_net.py
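
A hypothetical call showing how the returned pair is typically consumed; the values are illustrative, and the Module.fit hand-off mirrors how mxnet-ssd drives training:

import logging
import mxnet as mx

ctx = [mx.cpu()]
scheduler = mx.lr_scheduler.FactorScheduler(step=5000, factor=0.5)
opt, opt_params = get_optimizer_params(optimizer='sgd', learning_rate=0.01,
                                       momentum=0.9, weight_decay=0.0005,
                                       lr_scheduler=scheduler, ctx=ctx,
                                       logger=logging.getLogger())
# then typically forwarded to the Module API:
# mod.fit(train_iter, ..., optimizer=opt, optimizer_params=opt_params)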

Example 4: __init__

# Required imports: import mxnet [as mx]
# or: from mxnet import lr_scheduler [as alias]
# (this excerpt also relies on os, pathlib, shutil, traceback, pprint.pformat,
#  mxnet.gluon, and project helpers such as save and setup_logger)
def __init__(self, config, model, criterion, ctx, sample_input):
        config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent),
                                                       config['trainer']['output_dir'])
        config['name'] = config['name'] + '_' + model.model_name
        self.save_dir = os.path.join(config['trainer']['output_dir'], config['name'])
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
        self.alphabet = config['dataset']['alphabet']

        if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '':
            shutil.rmtree(self.save_dir, ignore_errors=True)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        # save this run's alphabet alongside the saved model
        save(list(self.alphabet), os.path.join(self.save_dir, 'dict.txt'))
        self.global_step = 0
        self.start_epoch = 0
        self.config = config

        self.model = model
        self.criterion = criterion
        # logger and tensorboard
        self.tensorboard_enable = self.config['trainer']['tensorboard']
        self.epochs = self.config['trainer']['epochs']
        self.display_interval = self.config['trainer']['display_interval']
        if self.tensorboard_enable:
            from mxboard import SummaryWriter
            self.writer = SummaryWriter(self.save_dir, verbose=False)

        self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
        self.logger.info(pformat(self.config))
        self.logger.info(self.model)
        # device setup
        self.ctx = ctx
        mx.random.seed(2)  # fix the random seed for reproducibility

        self.logger.info('train with mxnet: {} and device: {}'.format(mx.__version__, self.ctx))
        self.metrics = {'val_acc': 0, 'train_loss': float('inf'), 'best_model': ''}

        # build the lr scheduler and optimizer from the config and hand them to the gluon Trainer
        schedule = self._initialize('lr_scheduler', mx.lr_scheduler)
        optimizer = self._initialize('optimizer', mx.optimizer, lr_scheduler=schedule)
        self.trainer = gluon.Trainer(self.model.collect_params(), optimizer=optimizer)

        if self.config['trainer']['resume_checkpoint'] != '':
            self._laod_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True)
        elif self.config['trainer']['finetune_checkpoint'] != '':
            self._laod_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False)

        if self.tensorboard_enable:
            try:
                # add graph
                from mxnet.gluon import utils as gutils
                self.model(sample_input)
                self.writer.add_graph(model)
            except Exception:
                self.logger.error(traceback.format_exc())
                self.logger.warn('add graph to tensorboard failed') 
Developer: WenmuZhou, Project: crnn.gluon, Lines: 58, Source: base_trainer.py
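
The _initialize helper used above is not part of this excerpt. Here is a plausible sketch of the config-driven factory pattern it follows; the 'type' and 'args' config keys are assumptions based on common trainer templates, not verified against crnn.gluon:

def _initialize(self, name, module, *args, **kwargs):
    # look up the class named in config[name]['type'] inside the given module
    # (e.g. mx.lr_scheduler) and build it with the config's arguments
    module_name = self.config[name]['type']
    module_args = dict(self.config[name].get('args', {}))
    module_args.update(kwargs)  # overrides such as lr_scheduler=schedule
    return getattr(module, module_name)(*args, **module_args)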


Note: the mxnet.lr_scheduler examples in this article were compiled by 纯净天空 from open-source code hosted on GitHub, MSDocs, and similar platforms. The snippets remain under their original authors' copyright; consult the corresponding project's license before using or redistributing them.