This article collects typical usage examples of the Python method torch.optim.Adam.state_dict. If you are unsure what Adam.state_dict does, how to call it, or what real-world usage looks like, the curated example below should help. You can also read further about the class the method belongs to, torch.optim.Adam.
One code example of the Adam.state_dict method is shown below.
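Before the full training example, here is a minimal sketch of what calling state_dict on an Adam optimizer involves; the toy model and tensor shapes are illustrative assumptions, not taken from the example below.

import torch
from torch import nn
from torch.optim import Adam

# Toy model and optimizer (illustrative names only).
model = nn.Linear(10, 1)
optimizer = Adam(model.parameters(), lr=1e-3)

# One training step so the optimizer accumulates per-parameter state.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()

# state_dict() returns a dict with 'state' (per-parameter buffers such as the
# exponential moving averages Adam maintains) and 'param_groups' (lr, betas, ...).
opt_state = optimizer.state_dict()

# The dict can be saved with torch.save and later restored into a freshly
# constructed optimizer via load_state_dict.
new_optimizer = Adam(model.parameters(), lr=1e-3)
new_optimizer.load_state_dict(opt_state)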
Example 1: _run
# Required import: from torch.optim import Adam [as alias]
# Or: from torch.optim.Adam import state_dict [as alias]
#......... part of the code is omitted here .........
    transfer batch to Torch: 1 ms, (1508424955295, 1508424955296)
    Generating one example time: 2~5 ms, (1508458881118, 1508458881122)
    Generating one document time: 50~400 ms, (1508458881118, 1508458881122)
    Generating one batch time: 650~700 ms, (1508458880690, 1508458881122)
    After changing to torch.sampler
    Generating one example time: 4~7 ms
    Generating one batch time: 900~1200 ms
    '''
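    # Note: helpers such as `current_milli_time`, `Progbar`, `DistributedMemory`,
    # `NegativeSampling`, `MODEL_NAME`, and `MODELS_DIR` are defined in the part
    # of the module omitted above. A plausible definition of the timing helper
    # (an assumption, not the original code) is:
    #   current_milli_time = lambda: int(round(time.time() * 1000))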
    model = DistributedMemory(
        vec_dim,
        num_docs=len(dataset),
        num_words=vocabulary_size)
    cost_func = NegativeSampling()
    optimizer = Adam(params=model.parameters(), lr=lr)
    logger = logging.getLogger('root')
    if torch.cuda.is_available():
        model.cuda()
        logger.info("Running on GPU - CUDA")
    else:
        logger.info("Running on CPU")
    logger.info("Dataset comprised of {:d} documents.".format(len(dataset)))
    logger.info("Vocabulary size is {:d}.\n".format(vocabulary_size))
    logger.info("Training started.")
    best_loss = float_info.max
    prev_model_file_path = ""
    progbar = Progbar(num_batches, batch_size=batch_size, total_examples=number_examples)
    for epoch_i in range(num_epochs):
        epoch_start_time = time.time()
        loss = []
        for batch_i in range(num_batches):
            start_time = current_milli_time()
            batch = next(data_generator)
            current_time = current_milli_time()
            print('data-prepare time: %d ms' % round(current_time - start_time))
            start_time = current_milli_time()
            x = model.forward(
                batch.context_ids,
                batch.doc_ids,
                batch.target_noise_ids)
            x = cost_func.forward(x)
            # x.data[0] extracts the scalar loss (pre-0.4 PyTorch idiom; newer
            # versions would use x.item()).
            loss.append(x.data[0])
            print('forward time: %d ms' % round(current_milli_time() - start_time))
            start_time = current_milli_time()
            model.zero_grad()
            x.backward()
            optimizer.step()
            print('backward time: %d ms' % round(current_milli_time() - start_time))
            progbar.update(epoch_i, batch_i)
            # _print_progress(epoch_i, batch_i, num_batches)
        # end of epoch
        loss = torch.mean(torch.FloatTensor(loss))
        is_best_loss = loss < best_loss
        best_loss = min(loss, best_loss)
        progbar.update(epoch_i, batch_i, [('loss', loss), ('best_loss', best_loss)])
        model_file_name = MODEL_NAME.format(
            data_file_name[:-4],
            model_ver,
            vec_combine_method,
            context_size,
            num_noise_words,
            vec_dim,
            batch_size,
            lr,
            epoch_i + 1,
            loss)
        model_file_path = join(MODELS_DIR, model_file_name)
        if not os.path.exists(MODELS_DIR):
            os.makedirs(MODELS_DIR)
        # Checkpoint the model weights together with the Adam optimizer state
        # (the latter captured via optimizer.state_dict(), the method this
        # example illustrates).
        state = {
            'epoch': epoch_i + 1,
            'model_state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer_state_dict': optimizer.state_dict()
        }
        if save_all:
            torch.save(state, model_file_path)
        elif is_best_loss:
            try:
                remove(prev_model_file_path)
            except FileNotFoundError:
                pass
            torch.save(state, model_file_path)
            prev_model_file_path = model_file_path
        epoch_total_time = round(time.time() - epoch_start_time)
        logger.info(" ({:d}s) - loss: {:.4f}".format(epoch_total_time, loss))
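To see what the saved optimizer_state_dict buys you, here is a hedged sketch of resuming from a checkpoint written by the example above; it assumes model, optimizer, and model_file_path are constructed as in that example, and it is not part of the original source.

# Sketch only: `model`, `optimizer`, and `model_file_path` as in Example 1.
checkpoint = torch.load(model_file_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
best_loss = checkpoint['best_loss']
# load_state_dict restores Adam's per-parameter moment estimates and
# param_group settings (lr, betas, eps), so training continues where it
# stopped instead of re-warming the optimizer from scratch.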