當前位置: 首頁>>代碼示例>>Python>>正文


Python Net.state_dict方法代碼示例

本文整理匯總了Python中net.Net.state_dict方法的典型用法代碼示例。如果您正苦於以下問題:Python Net.state_dict方法的具體用法?Python Net.state_dict怎麽用?Python Net.state_dict使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在net.Net的用法示例。


在下文中一共展示了Net.state_dict方法的2個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: Solver

# 需要導入模塊: from net import Net [as 別名]
# 或者: from net.Net import state_dict [as 別名]
class Solver():
    def __init__(self, args):      
        # prepare a datasets
        self.train_data = Dataset(train=True,
                                  data_root=args.data_root,
                                  size=args.image_size)
        self.test_data  = Dataset(train=False,
                                  data_root=args.data_root,
                                  size=args.image_size)
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=args.batch_size,
                                       num_workers=1,
                                       shuffle=True, drop_last=True)
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.net     = Net().to(self.device)
        self.loss_fn = torch.nn.L1Loss()
        self.optim   = torch.optim.Adam(self.net.parameters(), args.lr)
        
        self.args = args
        
        if not os.path.exists(args.ckpt_dir):
            os.makedirs(args.ckpt_dir)
        
    def fit(self):
        args = self.args

        for epoch in range(args.max_epochs):
            self.net.train()
            for step, inputs in enumerate(self.train_loader):
                gt_gray = inputs[0].to(self.device)
                gt_ab   = inputs[1].to(self.device)
                
                pred_ab = self.net(gt_gray)
                loss = self.loss_fn(pred_ab, gt_ab)
                
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

            if (epoch+1) % args.print_every == 0:
                print("Epoch [{}/{}] loss: {:.6f}".format(epoch+1, args.max_epochs, loss.item()))
                self.save(args.ckpt_dir, args.ckpt_name, epoch+1)

    def save(self, ckpt_dir, ckpt_name, global_step):
        save_path = os.path.join(
            ckpt_dir, "{}_{}.pth".format(ckpt_name, global_step))
        torch.save(self.net.state_dict(), save_path)
開發者ID:muncok,項目名稱:pytorch-exercise,代碼行數:51,代碼來源:solver.py

示例2: train

# 需要導入模塊: from net import Net [as 別名]
# 或者: from net.Net import state_dict [as 別名]
def train(args):
    # prepare the MNIST dataset
    train_dataset = datasets.MNIST(root="./data/",
                                   train=True, 
                                   transform=transforms.ToTensor(),
                                   download=True)

    test_dataset = datasets.MNIST(root="./data/",
                                  train=False, 
                                  transform=transforms.ToTensor())

    # create the data loader
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size, 
                              shuffle=True, drop_last=True)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.batch_size, 
                             shuffle=False)

    
    # turn on the CUDA if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    net = Net().to(device)
    loss_op = nn.CrossEntropyLoss()
    optim   = torch.optim.Adam(net.parameters(), lr=args.lr)

    for epoch in range(args.max_epochs):
        net.train()
        for step, inputs in enumerate(train_loader):
            images = inputs[0].to(device)
            labels = inputs[1].to(device)
            
            # forward-propagation
            outputs = net(images)
            loss = loss_op(outputs, labels)
            
            # back-propagation
            optim.zero_grad()
            loss.backward()
            optim.step()

        acc = evaluate(net, test_loader, device)
        print("Epoch [{}/{}] loss: {:.5f} test acc: {:.3f}"
              .format(epoch+1, args.max_epochs, loss.item(), acc))

    torch.save(net.state_dict(), "mnist-final.pth")
開發者ID:muncok,項目名稱:pytorch-exercise,代碼行數:50,代碼來源:train.py


注:本文中的net.Net.state_dict方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。