當前位置: 首頁>>代碼示例>>Python>>正文


Python data.size方法代碼示例

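本文整理匯總了Python中data.size方法（原標注為torch.utils.data.size）的典型用法代碼示例。需要說明的是：示例中的data.size並不是torch.utils.data模塊中的函數，而是在名為data的張量上調用的torch.Tensor.size()方法——它返回張量的形狀，data.size(0)即第一維的大小（在這些示例中多為批次大小）。因此，各示例前「from torch.utils.data import size」一類的導入說明並不成立。如果您正苦於以下問題：Python data.size方法的具體用法？Python data.size怎麼用？Python data.size使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解torch.Tensor.size以及torch.utils.data的用法示例。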


在下文中一共展示了data.size方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: test

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def test(epoch):
    """Evaluate the VAE on the test set and save a reconstruction preview.

    Relies on module-level names defined elsewhere: ``model``, ``test_loader``,
    ``device``, ``loss_function`` and ``save_image``.

    Args:
        epoch: epoch number; used only to name the saved preview image.

    Prints the accumulated test loss divided by the number of test samples.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                # For the first batch only, show up to 8 input digits with the
                # reconstructed digits in the row right below them.
                n = min(data.size(0), 8)
                # Reshape with the actual batch size: the loader's first batch
                # is smaller than args.batch_size whenever the loader uses a
                # different batch size or the dataset holds fewer samples.
                comparison = torch.cat([data[:n],
                                        recon_batch.view(data.size(0), 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

########## create VAE ########## 
開發者ID:jgvfwstone,項目名稱:ArtificialIntelligenceEngines,代碼行數:24,代碼來源:main.py

示例2: test

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def test(epoch):
    """Evaluate the VAE on the test set and save a reconstruction preview.

    Relies on module-level names defined elsewhere: ``model``, ``test_loader``,
    ``args`` (``args.cuda``), ``loss_function`` and ``save_image``.

    Args:
        epoch: epoch number; used only to name the saved preview image.

    Prints the accumulated test loss divided by the number of test samples.
    """
    model.eval()
    test_loss = 0
    # torch.no_grad() replaces Variable(..., volatile=True), which has been a
    # no-op (with a warning) since PyTorch 0.4.
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            if args.cuda:
                data = data.cuda()
            recon_batch, mu, logvar = model(data)
            # .item() replaces .data[0]; indexing a 0-dim tensor raises
            # IndexError in current PyTorch.
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                # Show up to 8 inputs with their reconstructions right below.
                n = min(data.size(0), 8)
                # Reshape with the real batch size, not args.batch_size, so a
                # short first batch does not break the view.
                comparison = torch.cat([data[:n],
                                        recon_batch.view(data.size(0), 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
開發者ID:eelxpeng,項目名稱:UnsupervisedDeepLearning-Pytorch,代碼行數:20,代碼來源:test_vae_pytorch_example.py

示例3: test

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def test(epoch):
    """Evaluate the VAE on the test set and save a reconstruction preview.

    Relies on module-level names defined elsewhere: ``model``, ``test_loader``,
    ``device``, ``loss_function`` and ``save_image``.

    Args:
        epoch: epoch number; used only to name the saved preview image.

    Prints the accumulated test loss divided by the number of test samples.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                # Show up to 8 inputs with their reconstructions right below.
                n = min(data.size(0), 8)
                # Use the real batch size rather than args.batch_size so a
                # short first batch does not make the view fail.
                comparison = torch.cat([data[:n],
                                        recon_batch.view(data.size(0), 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
開發者ID:pytorch,項目名稱:examples,代碼行數:19,代碼來源:main.py

示例4: calc_gradient_penalty

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def calc_gradient_penalty(netD, real_data, fake_data, device='cpu', pac=10, lambda_=10):
    alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
    alpha = alpha.repeat(1, pac, real_data.size(1))
    alpha = alpha.view(-1, real_data.size(1))

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    # interpolates = torch.Variable(interpolates, requires_grad=True, device=device)

    disc_interpolates = netD(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size(), device=device),
        create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradient_penalty = (
        (gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1) ** 2).mean() * lambda_
    return gradient_penalty 
開發者ID:sdv-dev,項目名稱:SDGym,代碼行數:21,代碼來源:ctgan.py

示例5: __getitem__

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def __getitem__(self, index):
    """Return ``(data, rois, labels, info)`` for the roidb entry at *index*.

    The ROI array produced by ``get_minibatch`` is shuffled in place with the
    NumPy RNG and then truncated to ``self.max_rois_size`` rows. ``info`` is a
    float tensor ``[num_rois, height, width]``.
    """
    # The sample index is used directly as the roidb index (one entry per call).
    blobs = get_minibatch([self._roidb[index]], self._num_classes)

    roi_array = blobs['rois']
    np.random.shuffle(roi_array)
    rois = torch.from_numpy(roi_array[:self.max_rois_size])

    image = torch.from_numpy(blobs['data'])
    labels = torch.from_numpy(blobs['labels'])
    height = image.size(1)
    width = image.size(2)

    # Move the last axis to position 1, then view as (3, H, W). This assumes a
    # single image with 3 channels; the view fails otherwise.
    image = image.permute(0, 3, 1, 2).contiguous().view(3, height, width)

    info = torch.Tensor([rois.size(0), height, width])
    return image, rois, labels, info
開發者ID:jd730,項目名稱:OICR-pytorch,代碼行數:19,代碼來源:roibatchLoader.py

示例6: train

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def train(self):
    """Run one training pass over ``self.train_loader``.

    Returns:
        ``(train_loss, train_acc)``: the ``Average`` and ``Accuracy`` meters,
        updated once per batch.
    """
    self.model.train()

    loss_meter, acc_meter = Average(), Accuracy()
    model, optimizer, device = self.model, self.optimizer, self.device

    for inputs, labels in self.train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # The batch size is passed alongside the scalar loss.
        loss_meter.update(batch_loss.item(), inputs.size(0))
        acc_meter.update(logits, labels)

    return loss_meter, acc_meter
開發者ID:narumiruna,項目名稱:pytorch-distributed-example,代碼行數:23,代碼來源:main.py

示例7: evaluate

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def evaluate(self):
    """Evaluate ``self.model`` on ``self.test_loader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built. Without it,
    every forward pass keeps activations alive for a backward pass that never
    happens.

    Returns:
        ``(test_loss, test_acc)``: the ``Average`` and ``Accuracy`` meters,
        updated once per batch.
    """
    import torch  # for torch.no_grad()

    self.model.eval()

    test_loss = Average()
    test_acc = Accuracy()

    with torch.no_grad():
        for data, target in self.test_loader:
            data = data.to(self.device)
            target = target.to(self.device)

            output = self.model(data)
            loss = F.cross_entropy(output, target)

            # The batch size is passed alongside the scalar loss.
            test_loss.update(loss.item(), data.size(0))
            test_acc.update(output, target)

    return test_loss, test_acc
開發者ID:narumiruna,項目名稱:pytorch-distributed-example,代碼行數:19,代碼來源:main.py

示例8: validate

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def validate(epoch, model, loader, prefix='Validation'):
    """Compute the average negative log-likelihood of *model* over *loader*.

    A batch may be a bare tensor or a list ``[inputs, cond, ...]``. When a
    second element is present, it is cast to float and passed to
    ``model.log_probs`` as conditioning data; otherwise ``None`` is passed.

    Relies on module-level ``device``, ``writer`` and ``tqdm``.

    Args:
        epoch: step value used for the TensorBoard scalar.
        model: object exposing ``eval()`` and ``log_probs(inputs, cond)``.
        loader: data loader whose ``dataset`` length is the sample count.
        prefix: unused; kept for interface compatibility.

    Returns:
        Summed negative log-probability divided by ``len(loader.dataset)``.
    """
    global global_step, writer

    model.eval()
    val_loss = 0

    pbar = tqdm(total=len(loader.dataset))
    pbar.set_description('Eval')
    try:
        for data in loader:
            # Reset for every batch. A bare-tensor batch has no conditioning;
            # without this reset it would hit a NameError on cond_data.
            cond_data = None
            if isinstance(data, list):
                if len(data) > 1:
                    cond_data = data[1].float().to(device)
                data = data[0]
            data = data.to(device)
            with torch.no_grad():
                val_loss += -model.log_probs(data, cond_data).sum().item()  # sum up batch loss
            pbar.update(data.size(0))
            pbar.set_description('Val, Log likelihood in nats: {:.6f}'.format(
                -val_loss / pbar.n))

        writer.add_scalar('validation/LL', val_loss / len(loader.dataset), epoch)
    finally:
        # Always release the progress bar, even if evaluation raises.
        pbar.close()
    return val_loss / len(loader.dataset)
開發者ID:ikostrikov,項目名稱:pytorch-flows,代碼行數:30,代碼來源:main.py

示例9: __init__

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def __init__(self, mode, noisy_for_train, sentiment, direction):
    """Load one sentiment split, tokenize it and batchify the token ids.

    Args:
        mode: split name; used in the data file name. Noise is enabled only
            when it equals ``'train'``.
        noisy_for_train: whether to enable noise for the training split.
        sentiment: sentiment label used in the data file name.
        direction: ``'forward'`` keeps token order; ``'backward'`` reverses
            each sentence, including its ``<go>``/``<eos>`` markers.

    Raises:
        ValueError: if ``direction`` is neither ``'forward'`` nor ``'backward'``.
    """
    # Validate up front, so a bad value is rejected even for an empty file
    # and before the vocabulary is loaded.
    if direction not in ('forward', 'backward'):
        raise ValueError(
            "direction must be 'forward' or 'backward', got {!r}".format(direction))

    self.mode = mode
    # NOTE(review): this Amazon loader reads from ../data/yelp while using the
    # Amazon vocabulary below; confirm the data root is intended.
    self.root = os.path.join('../data', 'yelp')
    self.noisy = self.mode == 'train' and noisy_for_train

    # Data file for this split and sentiment.
    path = os.path.join(self.root, 'sentiment.{}.{}'.format(mode, sentiment))

    # Load vocabulary.
    print('----- Loading vocab -----')
    self.vocab = Vocabulary('../data/amazon/amazon.vocab')
    print('vocabulary size:', self.vocab.size)
    word2id = self.vocab.word2id
    self.pad = word2id['<pad>']
    self.go = word2id['<go>']
    self.eos = word2id['<eos>']
    self.unk = word2id['<unk>']

    # Wrap each line in <go> ... <eos>, optionally reverse it, and append
    # its ids to one flat stream; out-of-vocabulary words map to <unk>.
    reverse = direction == 'backward'
    ids = []
    with open(path, 'r') as f:
        for line in f:
            words = ['<go>'] + line.split() + ['<eos>']
            if reverse:
                words.reverse()
            ids.extend(word2id.get(word, self.unk) for word in words)
    self.ids = torch.LongTensor(ids)  # (very_long, )
    self.ids = batchify(self.ids, config.batch_size, config)  # shape = (???, batch_size)
開發者ID:ChenWu98,項目名稱:Point-Then-Operate,代碼行數:34,代碼來源:amazon.py

示例10: batchify

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def batchify(data, bsz, args):
    """Arrange a 1-D token stream into ``bsz`` parallel columns.

    Trailing elements that do not fill a complete row are dropped. Column ``j``
    of the result holds the ``j``-th contiguous chunk of the stream.

    Args:
        data: 1-D tensor of token ids.
        bsz: number of columns (batch size).
        args: object with a boolean ``gpu`` attribute; if true the result is
            moved to the GPU with ``.cuda()``.

    Returns:
        Contiguous tensor of shape ``(len(data) // bsz, bsz)``. When ``data``
        is shorter than ``bsz``, the shape is ``(0, bsz)``.
    """
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches. An explicit nbatch
    # (instead of -1) keeps this valid when nbatch is 0.
    data = data.view(bsz, nbatch).t().contiguous()
    if args.gpu:
        data = data.cuda()
    return data
開發者ID:ChenWu98,項目名稱:Point-Then-Operate,代碼行數:12,代碼來源:amazon.py

示例11: __init__

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def __init__(self, mode, noisy_for_train, sentiment, direction):
    """Load one Yelp sentiment split, tokenize it and batchify the token ids.

    Args:
        mode: split name; used in the data file name. Noise is enabled only
            when it equals ``'train'``.
        noisy_for_train: whether to enable noise for the training split.
        sentiment: sentiment label used in the data file name.
        direction: ``'forward'`` keeps token order; ``'backward'`` reverses
            each sentence, including its ``<go>``/``<eos>`` markers.

    Raises:
        ValueError: if ``direction`` is neither ``'forward'`` nor ``'backward'``.
    """
    # Validate up front, so a bad value is rejected even for an empty file
    # and before the vocabulary is loaded.
    if direction not in ('forward', 'backward'):
        raise ValueError(
            "direction must be 'forward' or 'backward', got {!r}".format(direction))

    self.mode = mode
    self.root = os.path.join('../data', 'yelp')
    voc_f = os.path.join(self.root, 'yelp.vocab')
    self.noisy = self.mode == 'train' and noisy_for_train

    # Data file for this split and sentiment.
    path = os.path.join(self.root, 'sentiment.{}.{}'.format(mode, sentiment))

    # Load vocabulary.
    print('----- Loading vocab -----')
    self.vocab = Vocabulary(voc_f)
    print('vocabulary size:', self.vocab.size)
    word2id = self.vocab.word2id
    self.pad = word2id['<pad>']
    self.go = word2id['<go>']
    self.eos = word2id['<eos>']
    self.unk = word2id['<unk>']

    # Wrap each line in <go> ... <eos>, optionally reverse it, and append
    # its ids to one flat stream; out-of-vocabulary words map to <unk>.
    reverse = direction == 'backward'
    ids = []
    with open(path, 'r') as f:
        for line in f:
            words = ['<go>'] + line.split() + ['<eos>']
            if reverse:
                words.reverse()
            ids.extend(word2id.get(word, self.unk) for word in words)
    self.ids = torch.LongTensor(ids)  # (very_long, )
    self.ids = batchify(self.ids, config.batch_size, config)  # shape = (, batch_size)
開發者ID:ChenWu98,項目名稱:Point-Then-Operate,代碼行數:35,代碼來源:yelp.py

示例12: extract_and_crop_patches_by_predicted_transform

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def extract_and_crop_patches_by_predicted_transform(patches, trans, crop_size = 32):
    """Warp each patch by its predicted 2x2 transform and centre-crop it.

    A base LAF ``[[0.5, 0, 0.5], [0, 0.5, 0.5]]`` is built per patch. Its 2x2
    part is left-multiplied by ``trans``, and the result is passed to
    ``extract_patches`` with ``PS = patches.size(2)``.

    Args:
        patches: batch of patches; ``size(0)`` is the batch and ``size(2)`` the
            patch side (presumably (N, C, H, W) with H == W; confirm).
        trans: per-patch transforms, batch-multiplied as (N, 2, 2).
        crop_size: side of the centred crop taken from the warped patches.

    Returns:
        Contiguous tensor ``warped[:, :, st:st + crop_size, st:st + crop_size]``.

    Raises:
        ValueError: if ``patches`` and ``trans`` have different batch sizes.
    """
    if patches.size(0) != trans.size(0):
        raise ValueError('patches and trans batch sizes differ: {} vs {}'.format(
            patches.size(0), trans.size(0)))
    st = int((patches.size(2) - crop_size) / 2)
    fin = st + crop_size
    # Build directly on the patches' device (not just "the default GPU"). A
    # plain tensor replaces Variable, which is a no-op wrapper since PyTorch 0.4.
    base_LAF = torch.tensor([[0.5, 0, 0.5], [0, 0.5, 0.5]],
                            dtype=torch.float32, device=patches.device)
    rot_LAFs = base_LAF.unsqueeze(0).repeat(patches.size(0), 1, 1)
    trans = trans.to(patches.device)
    rot_LAFs1 = torch.cat([torch.bmm(trans, rot_LAFs[:, 0:2, 0:2]), rot_LAFs[:, 0:2, 2:]], dim=2)
    return extract_patches(patches, rot_LAFs1, PS=patches.size(2))[:, :, st:fin, st:fin].contiguous()
開發者ID:ducha-aiki,項目名稱:affnet,代碼行數:12,代碼來源:train_AffNet_test_on_graffity.py

示例13: extract_random_LAF

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def extract_random_LAF(data, max_rot = math.pi, max_tilt = 1.0, crop_size = 32):
    """Extract randomly rotated and affinely warped patches, plus centre crops.

    Args:
        data: batch of patches; ``size(2)`` is used as the patch side.
        max_rot: either a real number passed to ``get_random_rotation_LAFs``
            as the maximum rotation, or a ready-made LAF tensor used directly
            as the rotation.
        max_tilt: forwarded to ``get_random_norm_affine_LAFs``.
        crop_size: side of the centred crop taken from the warped patches.

    Returns:
        ``(data_affcrop, data_aff, rot_LAFs, inv_rotmat, inv_TA)``.
        ``inv_rotmat`` is ``None`` when ``max_rot`` was given as LAFs.
    """
    st = int((data.size(2) - crop_size) / 2)
    fin = st + crop_size
    # isinstance also accepts ints and NumPy float64 (a float subclass).
    # ``type(max_rot) is float`` sent those down the "LAF tensor" branch.
    if isinstance(max_rot, (int, float)):
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data, max_rot)
    else:
        rot_LAFs = max_rot
        inv_rotmat = None
    aff_LAFs, inv_TA = get_random_norm_affine_LAFs(data, max_tilt)
    # Compose the rotation with the random affine on their 2x2 linear parts.
    aff_LAFs[:, 0:2, 0:2] = torch.bmm(rot_LAFs[:, 0:2, 0:2], aff_LAFs[:, 0:2, 0:2])
    data_aff = extract_patches(data, aff_LAFs, PS=data.size(2))
    data_affcrop = data_aff[:, :, st:fin, st:fin].contiguous()
    return data_affcrop, data_aff, rot_LAFs, inv_rotmat, inv_TA
開發者ID:ducha-aiki,項目名稱:affnet,代碼行數:15,代碼來源:train_AffNet_test_on_graffity.py

示例14: load_grayscale_var

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def load_grayscale_var(fname):
    """Load an image file as a float32 grayscale tensor of shape (1, 1, H, W).

    The image is converted to RGB and its three channels are averaged. The
    tensor is moved to the GPU when the module-level ``args.cuda`` is set.
    """
    # Context manager closes the underlying file deterministically.
    with Image.open(fname) as img:
        rgb = np.array(img.convert('RGB'))
    gray = np.mean(rgb, axis=2).astype(np.float32)
    # Plain tensor: Variable(..., volatile=True) is a deprecated no-op since
    # PyTorch 0.4, and a tensor created from NumPy does not require grad.
    image = torch.from_numpy(gray)
    var_image_reshape = image.view(1, 1, image.size(0), image.size(1))
    if args.cuda:
        var_image_reshape = var_image_reshape.cuda()
    return var_image_reshape
開發者ID:ducha-aiki,項目名稱:affnet,代碼行數:10,代碼來源:train_AffNet_test_on_graffity.py

示例15: input_norm

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import size [as 別名]
def input_norm(self, x):
    """Standardize each sample of *x* to zero mean and unit std.

    Statistics are computed per sample over all non-batch elements, with
    ``torch.std``'s default (unbiased) estimator and ``1e-7`` added to the std.
    Both statistics are detached, so no gradient flows through them.

    Works for any input rank >= 2. The previous unsqueeze/expand_as chain only
    supported 4-D inputs.
    """
    flat = x.reshape(x.size(0), -1)
    # (B,) -> (B, 1, ..., 1) so the statistics broadcast against x.
    bshape = (-1,) + (1,) * (x.dim() - 1)
    mp = torch.mean(flat, dim=1).detach().view(bshape)
    sp = (torch.std(flat, dim=1) + 1e-7).detach().view(bshape)
    return (x - mp) / sp
開發者ID:ducha-aiki,項目名稱:affnet,代碼行數:7,代碼來源:train_OriNet_test_on_graffity.py


注:本文中的torch.utils.data.size方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。