當前位置: 首頁>>代碼示例>>Python>>正文


Python nn.MSELoss類代碼示例

本文整理匯總了Python中torch.nn.MSELoss類的典型用法代碼示例。如果您想了解nn.MSELoss的具體用法、調用方式或使用示例，下面精選的代碼示例或許能為您提供幫助。您也可以進一步了解該類所在模塊torch.nn的用法示例。


下文共展示了nn.MSELoss類的15個代碼示例，這些示例默認按受歡迎程度排序。您可以為喜歡或覺得有用的代碼點讚，您的評價將有助於系統推薦出更好的Python代碼示例。

示例1: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(self, config, net):
        """Set up the training agent from ``config`` around the network ``net``.

        Stores the log/model directories and device from ``config``, creates a
        ``TrainClock``, records which auxiliary loss terms are enabled together
        with their weights, builds the MSE and triplet-margin criteria, and
        creates an Adam optimizer (learning rate ``config.lr``) with an
        exponential LR scheduler (gamma 0.99).
        """
        # --- bookkeeping -------------------------------------------------
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.clock = TrainClock()
        self.device = config.device

        # --- switches for the optional loss terms ------------------------
        self.use_triplet = config.use_triplet
        self.use_footvel_loss = config.use_footvel_loss

        # --- criteria and their weights -----------------------------------
        self.mse = nn.MSELoss()
        margin = config.triplet_margin
        self.tripletloss = nn.TripletMarginLoss(margin=margin)
        self.triplet_weight = config.triplet_weight
        self.foot_idx = config.foot_idx
        self.footvel_loss_weight = config.footvel_loss_weight

        # --- optimisation: Adam, LR multiplied by 0.99 on each scheduler.step()
        trainable = self.net.parameters()
        self.optimizer = optim.Adam(trainable, config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, 0.99)
開發者ID:ChrisWu1997,項目名稱:2D-Motion-Retargeting,代碼行數:22,代碼來源:base_agent.py

示例2: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) - - label value used as the target for real images
            target_fake_label (float) - - label value used as the target for fake images

        Raises:
            NotImplementedError - - if gan_mode is not one of the supported objectives.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # The labels are stored as registered buffers (0-dim tensors) rather
        # than plain attributes.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            # Least-squares GAN: squared error against the label value.
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            # Sigmoid is folded into the loss, hence no sigmoid in the discriminator.
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgangp':
            # No criterion object is built for WGAN-GP; presumably the loss is
            # computed directly from the predictions elsewhere — confirm.
            self.loss = None
        else:
            # Wrap the argument in a 1-tuple so a tuple-valued gan_mode still
            # produces the intended NotImplementedError instead of a TypeError
            # from %-formatting.
            raise NotImplementedError('gan mode %s not implemented' % (gan_mode,))
開發者ID:Mingtzge,項目名稱:2019-CCF-BDCI-OCR-MCZJ-OCR-IdentificationIDElement,代碼行數:25,代碼來源:networks.py

示例3: print_network

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def print_network(net):
    """Print ``net`` followed by its total number of parameters.

    Args:
        net: a module exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The total number of scalar parameters, i.e. the sum of ``numel()``
        over ``net.parameters()``. Returning it lets callers log or check the
        count without parsing stdout; callers that ignore the return value are
        unaffected.
    """
    num_params = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)
    return num_params


##############################################################################
# Classes
##############################################################################


# Defines the GAN loss, which uses either the LSGAN or the regular GAN objective.
# When LSGAN is used it is essentially the same as MSELoss,
# but it abstracts away the need to create a target label tensor
# with the same size as the input.
開發者ID:joelmoniz,項目名稱:DepthNets,代碼行數:19,代碼來源:networks.py

示例4: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        """Choose the adversarial criterion for ``gan_type`` (case-insensitive).

        'gan' and 'ragan' use BCE-with-logits, 'lsgan' uses MSE, and
        'wgan-gp' uses a mean-based critic loss; any other value raises
        NotImplementedError. The real/fake label values are stored as given.
        """
        super(GANLoss, self).__init__()
        kind = gan_type.lower()
        self.gan_type = kind
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if kind in ('gan', 'ragan'):
            self.loss = nn.BCEWithLogitsLoss()
        elif kind == 'lsgan':
            self.loss = nn.MSELoss()
        elif kind == 'wgan-gp':
            def wgan_loss(input, target):
                # ``target`` is a boolean: truthy -> negated mean, else the mean.
                if target:
                    return -1 * input.mean()
                return input.mean()

            self.loss = wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
開發者ID:cszn,項目名稱:KAIR,代碼行數:20,代碼來源:loss.py

示例5: define_loss

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def define_loss(self):
        """Build the generator loss selected in the training options.

        Reads ``self.opt_train['G_lossfn_type']`` ('l1', 'l2', 'l2sum' or
        'ssim'), moves the matching criterion to ``self.device`` and stores it
        in ``self.G_lossfn``; then stores ``opt_train['G_lossfn_weight']`` in
        ``self.G_lossfn_weight``.

        Raises:
            NotImplementedError: for any other loss type (before the weight
                is read).
        """
        opts = self.opt_train
        loss_kind = opts['G_lossfn_type']
        # Factories are only invoked for the selected entry, so SSIMLoss is
        # looked up only when 'ssim' is requested.
        factories = {
            'l1': nn.L1Loss,
            'l2': nn.MSELoss,
            'l2sum': lambda: nn.MSELoss(reduction='sum'),
            'ssim': lambda: SSIMLoss(),
        }
        make_loss = factories.get(loss_kind)
        if make_loss is None:
            raise NotImplementedError('Loss type [{:s}] is not found.'.format(loss_kind))
        self.G_lossfn = make_loss().to(self.device)
        self.G_lossfn_weight = opts['G_lossfn_weight']

    # ----------------------------------------
    # define optimizer
    # ---------------------------------------- 
開發者ID:cszn,項目名稱:KAIR,代碼行數:19,代碼來源:model_plain2.py

示例6: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(
        self,
        layer_dims=None,
        task_name="regression",
        **kwargs,
    ):
        """Regression prediction head with an element-wise MSE loss.

        Args:
            layer_dims: layer sizes handed to ``FeedForwardBlock``. Defaults to
                ``[768, 1]``; a new list is created for every instance so heads
                never share (and never mutate) a common default list.
            task_name: name of the task this head is attached to.
            **kwargs: accepted for interface compatibility; not used here.
        """
        super(RegressionHead, self).__init__()
        # A literal list default is evaluated once and shared by every call
        # that omits the argument; build a fresh one per instance instead.
        if layer_dims is None:
            layer_dims = [768, 1]
        # num_labels could in most cases also be automatically retrieved from the data processor
        self.layer_dims = layer_dims
        self.feed_forward = FeedForwardBlock(self.layer_dims)
        # num_labels is being set to 2 since it is being hijacked to store the scaling factor and the mean
        self.num_labels = 2
        self.ph_output_type = "per_sequence_continuous"
        self.model_type = "regression"
        # reduction="none" keeps per-element losses; any reduction happens elsewhere.
        self.loss_fct = MSELoss(reduction="none")
        self.task_name = task_name
        self.generate_config()
開發者ID:deepset-ai,項目名稱:FARM,代碼行數:19,代碼來源:prediction_head.py

示例7: test_save_load_network

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def test_save_load_network(self):
        """Round-trip network weights through save_network / load_network.

        Saves ``self.model``'s network and points
        ``hp.load.network_chkpt_path`` at the written ``.pt`` file. It then
        checks that the checkpoint and the log file exist, loads the
        checkpoint into a freshly built Model, and asserts that every
        parameter matches the source model.
        """
        fresh_net = Net_arch(self.hp)
        self.loss_f = nn.MSELoss()
        fresh_model = Model(self.hp, fresh_net, self.loss_f)

        self.model.save_network(self.logger)
        ckpt_path = os.path.join(
            self.hp.log.chkpt_dir,
            "%s_%d.pt" % (self.hp.log.name, self.model.step),
        )
        self.hp.load.network_chkpt_path = ckpt_path

        # isfile() is False for a missing path, so it also covers exists().
        assert os.path.isfile(ckpt_path)
        assert os.path.isfile(self.hp.log.log_file_path)

        fresh_model.load_network(logger=self.logger)
        pairs = zip(fresh_model.net.parameters(), self.model.net.parameters())
        for restored, source in pairs:
            assert (restored == source).all()
開發者ID:ryul99,項目名稱:pytorch-project-template,代碼行數:23,代碼來源:model_test.py

示例8: test_save_load_state

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def test_save_load_state(self):
        """Round-trip the training state and compare weights, epoch and step.

        Saves ``self.model``'s training state and points
        ``hp.load.resume_state_path`` at the written ``.state`` file. It then
        checks that the state file and the log file exist, loads the state
        into a freshly built Model, and asserts that parameters, ``epoch`` and
        ``step`` all match the source model.
        """
        fresh_net = Net_arch(self.hp)
        self.loss_f = nn.MSELoss()
        fresh_model = Model(self.hp, fresh_net, self.loss_f)

        self.model.save_training_state(self.logger)
        state_path = os.path.join(
            self.hp.log.chkpt_dir,
            "%s_%d.state" % (self.hp.log.name, self.model.step),
        )
        self.hp.load.resume_state_path = state_path

        # isfile() is False for a missing path, so it also covers exists().
        assert os.path.isfile(state_path)
        assert os.path.isfile(self.hp.log.log_file_path)

        fresh_model.load_training_state(logger=self.logger)
        pairs = zip(fresh_model.net.parameters(), self.model.net.parameters())
        for restored, source in pairs:
            assert (restored == source).all()
        assert fresh_model.epoch == self.model.epoch
        assert fresh_model.step == self.model.step
開發者ID:ryul99,項目名稱:pytorch-project-template,代碼行數:25,代碼來源:model_test.py

示例9: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(self, discriminator, d_optimizer, size_average=True,
                 loss='L2', batch_acum=1, device='cpu'):
        """Store the discriminator and its optimizer and choose the criterion.

        Args:
            discriminator: discriminator module, stored as-is.
            d_optimizer: optimizer for the discriminator, stored as-is.
            size_average: legacy reduction flag for the 'L2' criterion;
                None or truthy -> mean over elements, falsy -> sum.
            loss: 'L2' (MSELoss, labels [1, -1, 0]), 'BCE'
                (BCEWithLogitsLoss, labels [1, 0, 1]) or 'Hinge' (no
                criterion object; ``self.labels`` is not set).
            batch_acum: stored as ``self.batch_acum``.
            device: stored as ``self.device``.

        Raises:
            ValueError: for any other ``loss`` value.
        """
        super().__init__()
        self.discriminator = discriminator
        self.d_optimizer = d_optimizer
        self.batch_acum = batch_acum
        if loss == 'L2':
            # Passing size_average positionally is deprecated in PyTorch;
            # translate it to the equivalent ``reduction`` string using the
            # same legacy rule (None counts as True).
            reduction = 'mean' if size_average is None or size_average else 'sum'
            self.loss = nn.MSELoss(reduction=reduction)
            self.labels = [1, -1, 0]
        elif loss == 'BCE':
            self.loss = nn.BCEWithLogitsLoss()
            self.labels = [1, 0, 1]
        elif loss == 'Hinge':
            self.loss = None
        else:
            raise ValueError('Unrecognized loss: {}'.format(loss))
        self.device = device
開發者ID:santi-pdp,項目名稱:pase,代碼行數:19,代碼來源:losses.py

示例10: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(self, config=None):
        """Build the CapsNet layers and an MSE criterion.

        With a truthy ``config`` every sub-module is constructed from the
        corresponding config fields; otherwise each one is built with its
        own default constructor arguments.
        """
        super(CapsNet, self).__init__()
        if not config:
            # No config: rely on every layer's defaults.
            self.conv_layer = ConvLayer()
            self.primary_capsules = PrimaryCaps()
            self.digit_capsules = DigitCaps()
            self.decoder = Decoder()
        else:
            self.conv_layer = ConvLayer(
                config.cnn_in_channels,
                config.cnn_out_channels,
                config.cnn_kernel_size,
            )
            self.primary_capsules = PrimaryCaps(
                config.pc_num_capsules,
                config.pc_in_channels,
                config.pc_out_channels,
                config.pc_kernel_size,
                config.pc_num_routes,
            )
            self.digit_capsules = DigitCaps(
                config.dc_num_capsules,
                config.dc_num_routes,
                config.dc_in_channels,
                config.dc_out_channels,
            )
            self.decoder = Decoder(
                config.input_width,
                config.input_height,
                config.cnn_in_channels,
            )

        # MSE criterion (presumably for the decoder's reconstruction loss — confirm).
        self.mse_loss = nn.MSELoss()
開發者ID:jindongwang,項目名稱:Pytorch-CapsuleNet,代碼行數:18,代碼來源:capsnet.py

示例11: train_qf

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def train_qf(self, expected_qval, obs_val, actions_val):
        """Take one optimizer step on the Q-function toward target values.

        Args:
            expected_qval: NumPy array of regression targets for the Q-values.
            obs_val: NumPy array of observations.
            actions_val: NumPy array of actions.

        Returns:
            The MSE loss computed before the update, as a 0-dim NumPy array.
        """
        # ``Variable`` is a no-op wrapper since PyTorch 0.4: plain tensors
        # carry autograd state. ``from_numpy`` yields CPU tensors, so
        # ``.float()`` is equivalent to the former ``.type(torch.FloatTensor)``.
        obs = torch.from_numpy(obs_val).float()
        actions = torch.from_numpy(actions_val).float()
        expected_q = torch.from_numpy(expected_qval).float()

        q_vals = self.qf(obs, actions)

        # Define loss function
        loss_fn = nn.MSELoss()
        loss = loss_fn(q_vals, expected_q)

        # Backpropagation and gradient descent
        self.qf_optimizer.zero_grad()
        loss.backward()
        self.qf_optimizer.step()

        # detach() is the supported replacement for the legacy ``.data``.
        return loss.detach().numpy()
開發者ID:nosyndicate,項目名稱:pytorchrl,代碼行數:24,代碼來源:svg.py

示例12: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(self, dim, dropout=0.2, slope=0.0):
        """Stacked denoising autoencoder with mirrored Linear encoder/decoder layers.

        Args:
            dim: layer widths. ``dim[0]`` is stored as ``self.in_dim``; for
                each consecutive pair an encoder layer ``dim[i] -> dim[i+1]``
                and a mirrored decoder layer ``dim[i+1] -> dim[i]`` are built.
            dropout: drop probability of the two Dropout modules.
            slope: stored as ``self.reluslope`` (presumably a leaky-ReLU
                negative slope — confirm in forward).
        """
        super(SDAE, self).__init__()
        self.in_dim = dim[0]
        self.nlayers = len(dim)-1
        self.reluslope = slope
        self.enc, self.dec = [], []
        for i in range(self.nlayers):
            self.enc.append(nn.Linear(dim[i], dim[i+1]))
            # enc/dec are plain lists, so each layer is also set as a named
            # attribute to register it as a submodule.
            setattr(self, 'enc_{}'.format(i), self.enc[-1])
            self.dec.append(nn.Linear(dim[i+1], dim[i]))
            setattr(self, 'dec_{}'.format(i), self.dec[-1])
        # base[i] chains the first i encoder layers (base[0] is empty).
        self.base = []
        for i in range(self.nlayers):
            self.base.append(nn.Sequential(*self.enc[:i]))
        self.dropmodule1 = nn.Dropout(p=dropout)
        self.dropmodule2 = nn.Dropout(p=dropout)
        # reduction='mean' is the non-deprecated spelling of size_average=True.
        self.loss = nn.MSELoss(reduction='mean')

        # initialization: N(0, 0.01^2) weights, zero biases
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.normal_(m.weight, std=1e-2)
                # Test the parameter itself: a Linear without bias has
                # m.bias == None, and None.data would raise AttributeError.
                if m.bias is not None:
                    init.constant_(m.bias, 0)
開發者ID:shahsohil,項目名稱:DCC,代碼行數:26,代碼來源:SDAE.py

示例13: run_rmse_net

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def run_rmse_net(model, variables, X_train, Y_train, n_iters=1000, lr=1e-3):
    """Train ``model`` with Adam on MSE, printing train/test loss every step.

    Args:
        model: module whose forward returns a sequence whose first element is
            the prediction; it must also provide ``set_sig(X, Y)``.
        variables: dict with tensors under 'X_train_', 'Y_train_',
            'X_test_' and 'Y_test_'.
        X_train, Y_train: unused; kept for interface compatibility (the
            training data is read from ``variables``).
        n_iters: number of optimisation steps (default 1000).
        lr: Adam learning rate (default 1e-3).

    Returns:
        ``model``, after ``model.set_sig`` has been called on the training data.
    """
    import torch

    opt = optim.Adam(model.parameters(), lr=lr)
    # MSELoss is stateless; one instance serves every iteration.
    mse = nn.MSELoss()

    for i in range(n_iters):
        opt.zero_grad()
        model.train()
        train_loss = mse(
            model(variables['X_train_'])[0], variables['Y_train_'])
        train_loss.backward()
        opt.step()

        model.eval()
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            test_loss = mse(
                model(variables['X_test_'])[0], variables['Y_test_'])

        # .item() replaces the pre-0.4 `.data[0]`, which raises on the
        # 0-dim loss tensors returned by current PyTorch.
        print(i, train_loss.item(), test_loss.item())

    model.eval()
    model.set_sig(variables['X_train_'], variables['Y_train_'])

    return model
開發者ID:locuslab,項目名稱:e2e-model-learning,代碼行數:23,代碼來源:nets.py

示例14: BPDA_attack

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def BPDA_attack(image,target, model, step_size = 1., iterations = 10, linf=False, transform_func=identity_transform):
    """Iteratively perturb ``image`` and return the adversarial example.

    Each iteration passes the current adversarial input through
    ``transform_func`` and obtains a gradient from ``get_cw_grad`` on the
    transformed input. It then steps the untransformed input by
    ``-step_size * g`` (``g`` replaced by its sign when ``linf`` is True),
    applies ``clip_bound``, and prints diagnostics.

    Args:
        image: input tensor.
        target: label, converted with ``label2tensor``.
        model: classifier called on the adversarial input.
        step_size: step length per iteration.
        iterations: number of iterations.
        linf: use sign-of-gradient steps and report the L-inf distance.
        transform_func: transformation applied before the gradient query.

    Returns:
        The final adversarial input as a NumPy array.
    """
    target = label2tensor(target)
    adv = image.detach().numpy()
    adv = torch.from_numpy(adv)
    adv.requires_grad_()
    # MSELoss is stateless; build it once rather than every iteration.
    l2 = nn.MSELoss()
    for _ in range(iterations):
        adv_def = transform_func(adv)
        adv_def.requires_grad_()
        # mean((0 - adv_def)**2). The original passed the Python int 0 as the
        # input, which current PyTorch rejects (mse_loss calls .size() on it);
        # an explicit zero tensor computes the same value and gradient.
        loss = l2(torch.zeros_like(adv_def), adv_def)
        # NOTE(review): this backward() accumulates a gradient into
        # adv_def.grad; whether get_cw_grad includes that accumulated
        # gradient depends on its implementation — confirm.
        loss.backward()
        g = get_cw_grad(adv_def, image, target, model)
        if linf:
            g = torch.sign(g)
        print(g.numpy().sum())
        adv = adv.detach().numpy() - step_size * g.numpy()
        adv = clip_bound(adv)
        adv = torch.from_numpy(adv)
        adv.requires_grad_()
        if linf:
            print('label', torch.argmax(model(adv)), 'linf', torch.max(torch.abs(adv - image)).detach().numpy())
        else:
            print('label', torch.argmax(model(adv)), 'l2', l2_norm(adv, image))
    return adv.detach().numpy()
開發者ID:DSE-MSU,項目名稱:DeepRobust,代碼行數:26,代碼來源:BPDA.py

示例15: __init__

# 需要導入模塊: from torch import nn [as 別名]
# 或者: from torch.nn import MSELoss [as 別名]
def __init__(self):
    """Create the BCE and MSE criteria and two empty result lists."""
    # criteria
    self.bce_loss, self.mse_loss = nn.BCELoss(), nn.MSELoss()
    # result lists for the BCE and MSE metrics, initially empty
    self.bce_results, self.mse_results = [], []
開發者ID:jthsieh,項目名稱:DDPAE-video-prediction,代碼行數:7,代碼來源:metrics.py


注：本文中的torch.nn.MSELoss類示例由純淨天空整理自GitHub/MSDocs等開源代碼及文檔管理平台，相關代碼片段篩選自各路編程大神貢獻的開源項目，源碼版權歸原作者所有，傳播和使用請參考對應項目的License；未經允許，請勿轉載。