當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.narrow方法代碼示例

本文整理匯總了Python中torch.narrow方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.narrow方法的具體用法?Python torch.narrow怎麽用?Python torch.narrow使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.narrow方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import narrow [as 別名]
def forward(self, x):
    x = x.reshape(280, 280, 4)
    x = torch.narrow(x, dim=2, start=3, length=1)
    x = x.reshape(1, 1, 280, 280)
    x = F.avg_pool2d(x, 10, stride=10)
    x = x / 255
    x = (x - MEAN) / STANDARD_DEVIATION

    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.max_pool2d(x, 2)
    x = self.dropout1(x)
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.dropout2(x)
    x = self.fc2(x)
    output = F.softmax(x, dim=1)
    return output 
開發者ID:elliotwaite,項目名稱:pytorch-to-javascript-with-onnx-js,代碼行數:22,代碼來源:inference_mnist_model.py

示例2: image_eval

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import narrow [as 別名]
def image_eval(batch_size, context, similarity,
               true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy, true_prod_attributes,
               false_prod_texts, false_prod_text_lengths, false_prod_images, false_images_num, false_prod_taxonomies,
               false_prod_attributess):
    true_cos_sim = similarity(context, true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy,
                              true_prod_attributes)
    # mask
    mask = torch.zeros(batch_size, DatasetOption.num_neg_images + 1, dtype=torch.long).to(GlobalOption.device)
    mask.scatter_(1, false_images_num.view(-1, 1), 1)
    mask = 1 - mask.cumsum(dim=1)
    mask = torch.narrow(mask, 1, 0, DatasetOption.num_neg_images)
    mask.transpose_(0, 1)

    # number of negative images similarity greater than positive images
    gt = torch.zeros(batch_size, dtype=torch.long).to(GlobalOption.device)

    for i in range(DatasetOption.num_neg_images):
        false_cos_sim = similarity(context, false_prod_texts[i], false_prod_text_lengths[i], false_prod_images[i],
                                   false_prod_taxonomies[i], false_prod_attributess[i])
        gt += (false_cos_sim > true_cos_sim).long() * mask[i]
    return gt 
開發者ID:ChenTsuei,項目名稱:UMD,代碼行數:23,代碼來源:image_eval.py

示例3: image_loss

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import narrow [as 別名]
def image_loss(batch_size, context, similarity,
               true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy, true_prod_attributes,
               false_prod_texts, false_prod_text_lengths, false_prod_images, false_images_num, false_prod_taxonomies,
               false_prod_attributess):
    ones = torch.ones(batch_size).to(GlobalOption.device)
    zeros = torch.zeros(batch_size).to(GlobalOption.device)
    true_cos_sim = similarity(context, true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy,
                              true_prod_attributes)
    # mask
    mask = torch.zeros(batch_size, DatasetOption.num_neg_images + 1, dtype=torch.long).to(GlobalOption.device)
    mask.scatter_(1, false_images_num.view(-1, 1), 1)
    mask = 1 - mask.cumsum(dim=1)
    mask = torch.narrow(mask, 1, 0, DatasetOption.num_neg_images)
    mask.transpose_(0, 1)

    losses = []
    for i in range(DatasetOption.num_neg_images):
        false_cos_sim = similarity(context, false_prod_texts[i], false_prod_text_lengths[i], false_prod_images[i],
                                   false_prod_taxonomies[i], false_prod_attributess[i])
        loss = torch.max(zeros, ones - true_cos_sim + false_cos_sim)
        losses.append(loss)
    losses = torch.stack(losses)
    # losses: (#img_per_utter, batch)
    loss = losses.masked_select(mask.byte()).mean()
    return loss 
開發者ID:ChenTsuei,項目名稱:UMD,代碼行數:27,代碼來源:loss.py

示例4: slice_axis

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import narrow [as 別名]
def slice_axis(data, axis, begin, end):
    return th.narrow(data, axis, begin, end - begin) 
開發者ID:dmlc,項目名稱:dgl,代碼行數:4,代碼來源:tensor.py

示例5: get_keys

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import narrow [as 別名]
def get_keys(self, type_op, n_instances=1, remove=True):
        """
        Return FSS keys primitives

        Args:
            type_op: fss_eq, fss_comp, or xor_add_couple
            n_instances: how many primitives to retrieve. Comparison is pointwise so this is
                convenient: for any matrice of size nxm I can unstack n*m elements for the
                comparison
            remove: if true, pop out the primitive. If false, only read it. Read mode is
                needed because we're working on virtual workers and they need to gather
                a some point and then re-access the keys.
        """
        primitive_stack = getattr(self, type_op)

        available_instances = len(primitive_stack[0]) if len(primitive_stack) > 0 else -1
        if available_instances >= n_instances:
            keys = []
            # We iterate on the different elements that constitute a given primitive, for
            # example of the beaver triples, you would have 3 elements.
            for i, prim in enumerate(primitive_stack):
                # We're selecting on the last dimension of the tensor because it's simpler for
                # generating those primitives in crypto protocols
                # th.narrow(dim, index_start, length)
                keys.append(th.narrow(prim, -1, 0, n_instances))
                if remove:
                    length = prim.shape[-1] - n_instances
                    primitive_stack[i] = th.narrow(prim, -1, n_instances, length)

            return keys
        else:
            raise EmptyCryptoPrimitiveStoreError(self, type_op, available_instances, n_instances) 
開發者ID:OpenMined,項目名稱:PySyft,代碼行數:34,代碼來源:primitives.py

示例6: visit_slice

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import narrow [as 別名]
def visit_slice(self, op: Slice, network: PyTorchNetwork):
        self._add_computation(lambda a: a, op.o_output, (op.i_data,))
        for dim, start, end in zip(op.axes.get_value(), op.starts.get_value(), op.ends.get_value()):
            self._add_computation(torch.narrow, op.o_output, (op.o_output,)) 
開發者ID:deep500,項目名稱:deep500,代碼行數:6,代碼來源:pytorch_visitor.py


注:本文中的torch.narrow方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。