本文整理汇总了Python中mxnet.lr_scheduler方法的典型用法代码示例。如果您正苦于以下问题：Python mxnet.lr_scheduler方法的具体用法？Python mxnet.lr_scheduler怎么用？Python mxnet.lr_scheduler使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类mxnet的用法示例。
在下文中一共展示了mxnet.lr_scheduler方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_lr_scheduler
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import lr_scheduler [as 别名]
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute the starting learning rate and a multi-step decay scheduler.

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs at which the learning rate changes, e.g. "80,160";
        empty fields are ignored and the epochs may be given in any order
    lr_refactor_ratio : float
        lr *= ratio at each of those epochs; must be > 0, and a value >= 1
        disables scheduling altogether
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch; refactor epochs already reached are applied to the
        returned learning rate instead of being scheduled

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
        the scheduler is None when the ratio is >= 1 or when no refactor
        epoch lies after begin_epoch

    Raises:
    ---------
    AssertionError if lr_refactor_ratio is not positive
    ValueError if lr_refactor_step contains a non-integer field
    """
    assert lr_refactor_ratio > 0
    # Sorted so the resulting steps are increasing regardless of the order in
    # which the epochs were typed; MultiFactorScheduler rejects a step list
    # that is not increasing.
    iter_refactor = sorted(int(r) for r in lr_refactor_step.split(',') if r.strip())
    if lr_refactor_ratio >= 1:
        return (learning_rate, None)

    lr = learning_rate
    # Iterations per epoch (a trailing partial batch is not counted).
    epoch_size = num_example // batch_size
    # Apply every decay that should already have happened before begin_epoch.
    for s in iter_refactor:
        if begin_epoch >= s:
            lr *= lr_refactor_ratio
    if lr != learning_rate:
        logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
    # Remaining decay points, expressed in iterations counted from begin_epoch.
    steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
    if not steps:
        return (lr, None)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
    return (lr, lr_scheduler)
示例2: get_lr_scheduler
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import lr_scheduler [as 别名]
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute the starting learning rate and a burn-in multi-step scheduler.

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs at which the learning rate changes, e.g. "80,160";
        empty fields are ignored and the epochs may be given in any order
    lr_refactor_ratio : float
        lr *= ratio at each of those epochs; must be > 0, and a value >= 1
        disables scheduling altogether
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch; refactor epochs already reached are applied to the
        returned learning rate instead of being scheduled

    Returns:
    ---------
    (learning_rate, scheduler) as tuple
        the scheduler is a BurnInMultiFactorScheduler built with burn_in=1000,
        or None when the ratio is >= 1 or when no refactor epoch lies after
        begin_epoch

    Raises:
    ---------
    AssertionError if lr_refactor_ratio is not positive
    ValueError if lr_refactor_step contains a non-integer field
    """
    assert lr_refactor_ratio > 0
    # Sorted so the steps handed to the scheduler are increasing regardless of
    # the order in which the epochs were typed.
    # NOTE(review): BurnInMultiFactorScheduler is defined elsewhere; presumably
    # it expects increasing steps like mx's MultiFactorScheduler - confirm.
    iter_refactor = sorted(int(r) for r in lr_refactor_step.split(',') if r.strip())
    if lr_refactor_ratio >= 1:
        return (learning_rate, None)

    lr = learning_rate
    # Iterations per epoch (a trailing partial batch is not counted).
    epoch_size = num_example // batch_size
    # Apply every decay that should already have happened before begin_epoch.
    for s in iter_refactor:
        if begin_epoch >= s:
            lr *= lr_refactor_ratio
    if lr != learning_rate:
        logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
    # Remaining decay points, expressed in iterations counted from begin_epoch.
    steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
    if not steps:
        return (lr, None)
    lr_scheduler = BurnInMultiFactorScheduler(burn_in=1000, step=steps, factor=lr_refactor_ratio)
    return (lr, lr_scheduler)
示例3: get_optimizer_params
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import lr_scheduler [as 别名]
def get_optimizer_params(optimizer=None, learning_rate=None, momentum=None,
                         weight_decay=None, lr_scheduler=None, ctx=None, logger=None):
    """
    Build the optimizer name and keyword dictionary for a chosen optimizer.

    Parameters:
    ---------
    optimizer : str
        one of 'rmsprop', 'sgd', 'adadelta' or 'adam' (case-insensitive)
    learning_rate : float
        base learning rate; divided by 10 for RMSProp
    momentum : float
        momentum, only used for SGD
    weight_decay : float
        stored under 'wd' for RMSProp and SGD
    lr_scheduler : object
        scheduler passed through unchanged under 'lr_scheduler'
    ctx : sized collection or None
        training contexts; 'rescale_grad' is 1 / len(ctx), or 1.0 when ctx is
        None or empty
    logger : logger or None
        receives the RMSProp notice; a module logger is used when None

    Returns:
    ---------
    (optimizer_name, optimizer_params) as tuple
        optimizer_name is the lower-cased name; optimizer_params is an empty
        dict for 'adadelta'

    Raises:
    ---------
    ValueError if optimizer is None or not one of the supported names
    """
    if optimizer is None:
        raise ValueError("optimizer must be one of 'rmsprop', 'sgd', 'adadelta', 'adam'; got None")
    opt = optimizer.lower()

    # AdaDelta takes no parameters here, so ctx is never inspected for it.
    if opt == 'adadelta':
        return opt, {}
    if opt not in ('rmsprop', 'sgd', 'adam'):
        # Previously this fell through to the return statement and failed
        # with an UnboundLocalError that did not name the offending value.
        raise ValueError("unsupported optimizer: {!r}; expected one of "
                         "'rmsprop', 'sgd', 'adadelta', 'adam'".format(optimizer))

    # Average gradients over devices; a missing/empty ctx means no rescaling.
    num_ctx = len(ctx) if ctx is not None else 0
    rescale_grad = 1.0 / num_ctx if num_ctx > 0 else 1.0

    if opt == 'rmsprop':
        if logger is None:
            import logging
            logger = logging.getLogger(__name__)
        logger.info('you chose RMSProp, decreasing lr by a factor of 10')
        optimizer_params = {'learning_rate': learning_rate / 10.0,
                            'wd': weight_decay,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': rescale_grad}
    elif opt == 'sgd':
        optimizer_params = {'learning_rate': learning_rate,
                            'momentum': momentum,
                            'wd': weight_decay,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': rescale_grad}
    else:  # 'adam'
        optimizer_params = {'learning_rate': learning_rate,
                            'lr_scheduler': lr_scheduler,
                            'clip_gradient': None,
                            'rescale_grad': rescale_grad}
    return opt, optimizer_params
示例4: __init__
# 需要导入模块: import mxnet [as 别名]
# 或者: from mxnet import lr_scheduler [as 别名]
def __init__(self, config, model, criterion, ctx, sample_input):
    """
    Set up directories, logging, optimizer/trainer and optional checkpoint.

    Parameters:
    ---------
    config : dict
        experiment configuration; this method reads config['name'],
        config['dataset']['alphabet'] and, under config['trainer'],
        'output_dir', 'resume_checkpoint', 'finetune_checkpoint',
        'tensorboard', 'epochs' and 'display_interval'.
        NOTE: the dict is modified in place ('output_dir' is made absolute
        and the model name is appended to 'name').
    model : object
        network to train; must expose `model_name`, `collect_params()` and
        be callable on `sample_input`
    criterion : object
        loss object, stored as `self.criterion`
    ctx : object
        device context(s), stored as `self.ctx` and logged
    sample_input : object
        example input run through the model once before the graph is added
        to tensorboard (only used when tensorboard is enabled)
    """
    # NOTE(review): abspath(__name__) resolves the module *name* against the
    # current working directory, so `.parent` is the cwd; `__file__` may have
    # been intended - confirm before changing.
    config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent),
                                                   config['trainer']['output_dir'])
    config['name'] = config['name'] + '_' + model.model_name
    self.save_dir = os.path.join(config['trainer']['output_dir'], config['name'])
    self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
    self.alphabet = config['dataset']['alphabet']

    # A fresh run (neither resuming nor fine-tuning) wipes any previous
    # output of the same experiment name.
    if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '':
        shutil.rmtree(self.save_dir, ignore_errors=True)
    os.makedirs(self.checkpoint_dir, exist_ok=True)
    # Save this experiment's alphabet next to where the model is saved.
    save(list(self.alphabet), os.path.join(self.save_dir, 'dict.txt'))

    self.global_step = 0
    self.start_epoch = 0
    self.config = config
    self.model = model
    self.criterion = criterion

    # logger and tensorboard
    self.tensorboard_enable = self.config['trainer']['tensorboard']
    self.epochs = self.config['trainer']['epochs']
    self.display_interval = self.config['trainer']['display_interval']
    if self.tensorboard_enable:
        # Imported lazily so mxboard is only required when tensorboard is on.
        from mxboard import SummaryWriter
        self.writer = SummaryWriter(self.save_dir, verbose=False)

    self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
    self.logger.info(pformat(self.config))
    self.logger.info(self.model)

    # device set
    self.ctx = ctx
    mx.random.seed(2)  # set the random seed
    self.logger.info('train with mxnet: {} and device: {}'.format(mx.__version__, self.ctx))
    self.metrics = {'val_acc': 0, 'train_loss': float('inf'), 'best_model': ''}

    # NOTE(review): _initialize is defined elsewhere; presumably it builds the
    # named object from the matching config section - confirm.
    schedule = self._initialize('lr_scheduler', mx.lr_scheduler)
    optimizer = self._initialize('optimizer', mx.optimizer, lr_scheduler=schedule)
    self.trainer = gluon.Trainer(self.model.collect_params(), optimizer=optimizer)

    if self.config['trainer']['resume_checkpoint'] != '':
        self._laod_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True)
    elif self.config['trainer']['finetune_checkpoint'] != '':
        self._laod_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False)

    if self.tensorboard_enable:
        try:
            # add graph
            self.model(sample_input)
            self.writer.add_graph(model)
        except Exception:
            # Graph export is best-effort; log and continue, but let
            # KeyboardInterrupt/SystemExit propagate.
            self.logger.error(traceback.format_exc())
            self.logger.warning('add graph to tensorboard failed')