本文整理汇总了Python中lib.utils.get_logger方法的典型用法代码示例。如果您正苦于以下问题:Python utils.get_logger方法的具体用法?Python utils.get_logger怎么用?Python utils.get_logger使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类lib.utils
的用法示例。
在下文中一共展示了utils.get_logger方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train_toy
# 需要导入模块: from lib import utils [as 别名]
# 或者: from lib.utils import get_logger [as 别名]
def train_toy(toy, load=True, nb_steps=20, nb_flow=1, folder=""):
device = "cpu"
logger = utils.get_logger(logpath=os.path.join(folder, toy, 'logs'), filepath=os.path.abspath(__file__))
logger.info("Creating model...")
model = UMNNMAFFlow(nb_flow=nb_flow, nb_in=2, hidden_derivative=[100, 100, 100, 100], hidden_embedding=[100, 100, 100, 100],
embedding_s=10, nb_steps=nb_steps, device=device).to(device)
logger.info("Model created.")
opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)
if load:
logger.info("Loading model...")
model.load_state_dict(torch.load(folder + toy+'/model.pt'))
model.train()
opt.load_state_dict(torch.load(folder + toy+'/ADAM.pt'))
logger.info("Model loaded.")
nb_samp = 100
batch_size = 100
x_test = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)
x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)
for epoch in range(10000):
ll_tot = 0
start = timer()
for j in range(0, nb_samp, batch_size):
cur_x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=batch_size)).to(device)
ll, z = model.compute_ll(cur_x)
ll = -ll.mean()
ll_tot += ll.detach()/(nb_samp/batch_size)
loss = ll
opt.zero_grad()
loss.backward()
opt.step()
end = timer()
ll_test, _ = model.compute_ll(x_test)
ll_test = -ll_test.mean()
logger.info("epoch: {:d} - Train loss: {:4f} - Test loss: {:4f} - Elapsed time per epoch {:4f} (seconds)".
format(epoch, ll_tot.item(), ll_test.item(), end-start))
if (epoch % 100) == 0:
summary_plots(x, x_test, folder, epoch, model, ll_tot, ll_test)
torch.save(model.state_dict(), folder + toy + '/model.pt')
torch.save(opt.state_dict(), folder + toy + '/ADAM.pt')