This article collects typical usage examples of the tensorboard.SummaryWriter method in Python. If you are wondering what exactly tensorboard.SummaryWriter does and how to use it in your own code, the curated examples below may help. You can also read further about the tensorboard module that the method belongs to.
Below are 6 code examples of tensorboard.SummaryWriter, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
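Before diving into the examples, here is a minimal, self-contained sketch of the typical workflow. It assumes the standalone `tensorboard` pip package used throughout the examples below, whose add_scalar accepts a tag, a value, and an optional global step (as Example 6 demonstrates); the log directory './logs' is just an illustration:

from tensorboard import SummaryWriter

writer = SummaryWriter('./logs')  # event files are written under ./logs
for step in range(100):
    # tag, scalar value, optional global step
    writer.add_scalar('loss', 1.0 / (step + 1), step)
writer.close()  # flush buffered events to disk

Point TensorBoard at the same directory (`tensorboard --logdir ./logs`) to view the resulting curve.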
Example 1: __init__
# Required import: import tensorboard [as alias]
# Or: from tensorboard import SummaryWriter [as alias]
def __init__(self, sargs, input_vocabs, label_vocab, *args,
             val_data=None, **kwargs):
    super(LSTMCRFTrainer, self).__init__(*args, **kwargs)
    self.args = sargs
    self.input_vocabs = input_vocabs
    self.label_vocab = label_vocab
    self.val_data = val_data
    # The writer stays None unless tensorboard logging was requested.
    self.writer = None
    if self.args.tensorboard:
        # T is the tensorboard module imported under an alias.
        self.writer = T.SummaryWriter(self.args.save_dir)
    self.repeatables = {
        self.args.ckpt_period: self.save_checkpoint
    }
    if self.args.val:
        self.repeatables[self.args.val_period] = self.validate
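Since self.writer remains None when tensorboard logging is disabled, any later logging in the trainer presumably guards on it first; a hypothetical sketch (the `loss` and `step` names are illustrative, not from the original):

if self.writer is not None:
    self.writer.add_scalar('val_loss', loss, step)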
Example 2: __init__
# Required import: import tensorboard [as alias]
# Or: from tensorboard import SummaryWriter [as alias]
import logging

def __init__(self, dist_logging_dir=None, scalar_logging_dir=None,
             logfile_path=None, batch_size=None, iter_monitor=0,
             frequent=None, prefix='ssd'):
    self.scalar_logging_dir = scalar_logging_dir
    self.dist_logging_dir = dist_logging_dir
    self.logfile_path = logfile_path
    self.batch_size = batch_size
    self.iter_monitor = iter_monitor
    self.frequent = frequent
    self.prefix = prefix
    self.batch = 0
    self.line_idx = 0
    try:
        from tensorboard import SummaryWriter
        # Separate writers: one for distribution summaries, one for scalars.
        self.dist_summary_writer = SummaryWriter(dist_logging_dir)
        self.scalar_summary_writer = SummaryWriter(scalar_logging_dir)
    except ImportError:
        # Note: both writers are left undefined if the import fails.
        logging.error('You can install tensorboard via `pip install tensorboard`.')
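The two writers separate concerns: the scalar writer receives per-iteration metrics, while the distribution writer is meant for weight or gradient histograms. Assuming the package's tensorboard-pytorch style add_histogram method, a monitoring step might look roughly like this (the metric name and arrays are illustrative):

self.scalar_summary_writer.add_scalar('cross_entropy', loss_value, self.batch)
self.dist_summary_writer.add_histogram('conv1_weight', weight_array)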
Example 3: create_writer
# Required import: import tensorboard [as alias]
# Or: from tensorboard import SummaryWriter [as alias]
def create_writer():
    # LOG_DIR is a constant defined elsewhere in the original module.
    return SummaryWriter(LOG_DIR)
Example 4: __init__
# Required import: import tensorboard [as alias]
# Or: from tensorboard import SummaryWriter [as alias]
import logging

def __init__(self, logging_dir, prefix=None):
    self.prefix = prefix
    try:
        from tensorboard import SummaryWriter
        self.summary_writer = SummaryWriter(logging_dir)
    except ImportError:
        logging.error('You can install tensorboard via `pip install tensorboard`.')
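A class shaped like this is typically registered as an MXNet batch-end callback. The original snippet stops at the constructor; a hypothetical __call__ in the usual callback style (not shown in the source) could be:

def __call__(self, param):
    # `param` is MXNet's BatchEndParam; log every metric it carries.
    if param.eval_metric is None:
        return
    for name, value in param.eval_metric.get_name_value():
        if self.prefix is not None:
            name = '%s-%s' % (self.prefix, name)
        self.summary_writer.add_scalar(name, value)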
Example 5: test_log_scalar_summary
# Required import: import tensorboard [as alias]
# Or: from tensorboard import SummaryWriter [as alias]
def test_log_scalar_summary():
    logdir = './experiment/scalar'
    writer = SummaryWriter(logdir)
    for i in range(10):
        writer.add_scalar('test_scalar', i + 1)
    writer.close()
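Note that this test calls add_scalar without the optional global step; when you want values plotted against a training iteration, pass it explicitly, as Example 6 does:

for i in range(10):
    writer.add_scalar('test_scalar', i + 1, i)  # third argument: global step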
Example 6: do_training
# Required import: import tensorboard [as alias]
# Or: from tensorboard import SummaryWriter [as alias]
def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay):
    # `args`, `module`, `train_iter`, `val_iter`, `loss_metric` and
    # SimpleLRScheduler are defined elsewhere in the original script.
    summary_writer = SummaryWriter(args.tblog_dir)
    lr_scheduler = SimpleLRScheduler(learning_rate)
    optimizer_params = {'lr_scheduler': lr_scheduler}
    module.init_params()
    module.init_optimizer(kvstore=kvstore,
                          optimizer=optimizer,
                          optimizer_params=optimizer_params)
    n_epoch = 0
    while n_epoch < num_epoch:
        train_iter.reset()
        val_iter.reset()
        loss_metric.reset()
        # One pass over the training set.
        for n_batch, data_batch in enumerate(train_iter):
            module.forward_backward(data_batch)
            module.update()
            module.update_metric(loss_metric, data_batch.label)
            loss_metric.get_batch_log(n_batch)
        train_acc, train_loss, train_recon_err = loss_metric.get_name_value()
        loss_metric.reset()
        # One forward-only pass over the validation set.
        for n_batch, data_batch in enumerate(val_iter):
            module.forward(data_batch)
            module.update_metric(loss_metric, data_batch.label)
            loss_metric.get_batch_log(n_batch)
        val_acc, val_loss, val_recon_err = loss_metric.get_name_value()
        # Log one scalar per metric per epoch, keyed by the epoch index.
        summary_writer.add_scalar('train_acc', train_acc, n_epoch)
        summary_writer.add_scalar('train_loss', train_loss, n_epoch)
        summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch)
        summary_writer.add_scalar('val_acc', val_acc, n_epoch)
        summary_writer.add_scalar('val_loss', val_loss, n_epoch)
        summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch)
        print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f'
              % (n_epoch, train_acc, train_loss, train_recon_err))
        print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f'
              % (n_epoch, val_acc, val_loss, val_recon_err))
        print('SAVE CHECKPOINT')
        module.save_checkpoint(prefix=model_prefix, epoch=n_epoch)
        n_epoch += 1
        # Exponentially decay the learning rate after each epoch.
        lr_scheduler.learning_rate = learning_rate * (decay ** n_epoch)
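One caveat: do_training never closes its writer, so the final events may sit in a buffer if the process exits abruptly. A simple safeguard, added after the epoch loop:

summary_writer.close()  # flush any remaining events before returning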