當前位置: 首頁>>代碼示例>>Python>>正文


Python Net.eval方法代碼示例

本文整理匯總了Python中net.Net.eval方法的典型用法代碼示例。如果您正苦於以下問題:Python Net.eval方法的具體用法?Python Net.eval怎麽用?Python Net.eval使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在net.Net的用法示例。


在下文中一共展示了Net.eval方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: Solver

# 需要導入模塊: from net import Net [as 別名]
# 或者: from net.Net import eval [as 別名]
class Solver():
    def __init__(self, args):
        # prepare shakespeare dataset
        train_iter, data_info = load_shakespeare(args.batch_size, args.bptt_len)
        self.vocab_size = data_info["vocab_size"]
        self.TEXT = data_info["TEXT"]
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.net = Net(self.vocab_size, args.embed_dim, 
                       args.hidden_dim, args.num_layers).to(self.device)
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=1) # <pad>: 1
        self.optim   = torch.optim.Adam(self.net.parameters(), args.lr)
        
        self.args = args
        self.train_iter = train_iter
        
        if not os.path.exists(args.ckpt_dir):
            os.makedirs(args.ckpt_dir)
        
    def fit(self):
        args = self.args

        for epoch in range(args.max_epochs):
            self.net.train()
            for step, inputs in enumerate(self.train_iter):
                X = inputs.text.to(self.device)
                y = inputs.target.to(self.device)

                out, _ = self.net(X)
                loss = self.loss_fn(out, y.view(-1))

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            
            if (epoch+1) % args.print_every == 0:
                text = self.sample(args.sample_length, args.sample_prime)
                print("Epoch [{}/{}] loss: {:.3f}"
                    .format(epoch+1, args.max_epochs, loss.item()/args.bptt_len))
                print(text, "\n")
                self.save(args.ckpt_dir, args.ckpt_name, epoch+1)

    def sample(self, length, prime="The"):
        args = self.args

        self.net.eval()
        samples = list(prime)

        # convert prime string to torch.LongTensor type
        prime = self.TEXT.process(prime, device=self.device, train=False)
        
        # sample character indices
        indices = self.net.sample(prime, length)

        # convert char indices to string type
        for index in indices:
            out = self.TEXT.vocab.itos[index.item()]
            samples.append(out.replace("<eos>", "\n"))
        
        self.TEXT.sequential = True

        return "".join(samples)
                
    def save(self, ckpt_dir, ckpt_name, global_step):
        save_path = os.path.join(
            ckpt_dir, "{}_{}.pth".format(ckpt_name, global_step))
        torch.save(self.net.state_dict(), save_path)
開發者ID:muncok,項目名稱:pytorch-exercise,代碼行數:70,代碼來源:solver.py


注:本文中的net.Net.eval方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。