當前位置: 首頁>>代碼示例>>Python>>正文


Python utils.get_logger方法代碼示例

本文整理匯總了Python中lib.utils.get_logger方法的典型用法代碼示例。如果您正苦於以下問題:Python utils.get_logger方法的具體用法?Python utils.get_logger怎麽用?Python utils.get_logger使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在lib.utils的用法示例。


在下文中一共展示了utils.get_logger方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: train_toy

# 需要導入模塊: from lib import utils [as 別名]
# 或者: from lib.utils import get_logger [as 別名]
def train_toy(toy, load=True, nb_steps=20, nb_flow=1, folder=""):
    device = "cpu"
    logger = utils.get_logger(logpath=os.path.join(folder, toy, 'logs'), filepath=os.path.abspath(__file__))

    logger.info("Creating model...")
    model = UMNNMAFFlow(nb_flow=nb_flow, nb_in=2, hidden_derivative=[100, 100, 100, 100], hidden_embedding=[100, 100, 100, 100],
                        embedding_s=10, nb_steps=nb_steps, device=device).to(device)
    logger.info("Model created.")
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-5)

    if load:
        logger.info("Loading model...")
        model.load_state_dict(torch.load(folder + toy+'/model.pt'))
        model.train()
        opt.load_state_dict(torch.load(folder + toy+'/ADAM.pt'))
        logger.info("Model loaded.")

    nb_samp = 100
    batch_size = 100

    x_test = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)
    x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=1000)).to(device)

    for epoch in range(10000):
        ll_tot = 0
        start = timer()
        for j in range(0, nb_samp, batch_size):
            cur_x = torch.tensor(toy_data.inf_train_gen(toy, batch_size=batch_size)).to(device)
            ll, z = model.compute_ll(cur_x)
            ll = -ll.mean()
            ll_tot += ll.detach()/(nb_samp/batch_size)
            loss = ll
            opt.zero_grad()
            loss.backward()
            opt.step()
        end = timer()
        ll_test, _ = model.compute_ll(x_test)
        ll_test = -ll_test.mean()
        logger.info("epoch: {:d} - Train loss: {:4f} - Test loss: {:4f} - Elapsed time per epoch {:4f} (seconds)".
                    format(epoch, ll_tot.item(), ll_test.item(), end-start))

        if (epoch % 100) == 0:
            summary_plots(x, x_test, folder, epoch, model, ll_tot, ll_test)
            torch.save(model.state_dict(), folder + toy + '/model.pt')
            torch.save(opt.state_dict(), folder + toy + '/ADAM.pt') 
開發者ID:AWehenkel,項目名稱:UMNN,代碼行數:47,代碼來源:ToyExperiments.py


注:本文中的lib.utils.get_logger方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。