当前位置: 首页>>代码示例>>Python>>正文


Python torch.narrow方法代码示例

本文整理汇总了Python中torch.narrow方法的典型用法代码示例。如果您正苦于以下问题:Python torch.narrow方法的具体用法?Python torch.narrow怎么用?Python torch.narrow使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.narrow方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def forward(self, x):
    x = x.reshape(280, 280, 4)
    x = torch.narrow(x, dim=2, start=3, length=1)
    x = x.reshape(1, 1, 280, 280)
    x = F.avg_pool2d(x, 10, stride=10)
    x = x / 255
    x = (x - MEAN) / STANDARD_DEVIATION

    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.max_pool2d(x, 2)
    x = self.dropout1(x)
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.dropout2(x)
    x = self.fc2(x)
    output = F.softmax(x, dim=1)
    return output 
开发者ID:elliotwaite,项目名称:pytorch-to-javascript-with-onnx-js,代码行数:22,代码来源:inference_mnist_model.py

示例2: image_eval

# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def image_eval(batch_size, context, similarity,
               true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy, true_prod_attributes,
               false_prod_texts, false_prod_text_lengths, false_prod_images, false_images_num, false_prod_taxonomies,
               false_prod_attributess):
    true_cos_sim = similarity(context, true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy,
                              true_prod_attributes)
    # mask
    mask = torch.zeros(batch_size, DatasetOption.num_neg_images + 1, dtype=torch.long).to(GlobalOption.device)
    mask.scatter_(1, false_images_num.view(-1, 1), 1)
    mask = 1 - mask.cumsum(dim=1)
    mask = torch.narrow(mask, 1, 0, DatasetOption.num_neg_images)
    mask.transpose_(0, 1)

    # number of negative images similarity greater than positive images
    gt = torch.zeros(batch_size, dtype=torch.long).to(GlobalOption.device)

    for i in range(DatasetOption.num_neg_images):
        false_cos_sim = similarity(context, false_prod_texts[i], false_prod_text_lengths[i], false_prod_images[i],
                                   false_prod_taxonomies[i], false_prod_attributess[i])
        gt += (false_cos_sim > true_cos_sim).long() * mask[i]
    return gt 
开发者ID:ChenTsuei,项目名称:UMD,代码行数:23,代码来源:image_eval.py

示例3: image_loss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def image_loss(batch_size, context, similarity,
               true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy, true_prod_attributes,
               false_prod_texts, false_prod_text_lengths, false_prod_images, false_images_num, false_prod_taxonomies,
               false_prod_attributess):
    ones = torch.ones(batch_size).to(GlobalOption.device)
    zeros = torch.zeros(batch_size).to(GlobalOption.device)
    true_cos_sim = similarity(context, true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy,
                              true_prod_attributes)
    # mask
    mask = torch.zeros(batch_size, DatasetOption.num_neg_images + 1, dtype=torch.long).to(GlobalOption.device)
    mask.scatter_(1, false_images_num.view(-1, 1), 1)
    mask = 1 - mask.cumsum(dim=1)
    mask = torch.narrow(mask, 1, 0, DatasetOption.num_neg_images)
    mask.transpose_(0, 1)

    losses = []
    for i in range(DatasetOption.num_neg_images):
        false_cos_sim = similarity(context, false_prod_texts[i], false_prod_text_lengths[i], false_prod_images[i],
                                   false_prod_taxonomies[i], false_prod_attributess[i])
        loss = torch.max(zeros, ones - true_cos_sim + false_cos_sim)
        losses.append(loss)
    losses = torch.stack(losses)
    # losses: (#img_per_utter, batch)
    loss = losses.masked_select(mask.byte()).mean()
    return loss 
开发者ID:ChenTsuei,项目名称:UMD,代码行数:27,代码来源:loss.py

示例4: slice_axis

# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def slice_axis(data, axis, begin, end):
    return th.narrow(data, axis, begin, end - begin) 
开发者ID:dmlc,项目名称:dgl,代码行数:4,代码来源:tensor.py

示例5: get_keys

# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def get_keys(self, type_op, n_instances=1, remove=True):
        """
        Return FSS keys primitives

        Args:
            type_op: fss_eq, fss_comp, or xor_add_couple
            n_instances: how many primitives to retrieve. Comparison is pointwise so this is
                convenient: for any matrice of size nxm I can unstack n*m elements for the
                comparison
            remove: if true, pop out the primitive. If false, only read it. Read mode is
                needed because we're working on virtual workers and they need to gather
                a some point and then re-access the keys.
        """
        primitive_stack = getattr(self, type_op)

        available_instances = len(primitive_stack[0]) if len(primitive_stack) > 0 else -1
        if available_instances >= n_instances:
            keys = []
            # We iterate on the different elements that constitute a given primitive, for
            # example of the beaver triples, you would have 3 elements.
            for i, prim in enumerate(primitive_stack):
                # We're selecting on the last dimension of the tensor because it's simpler for
                # generating those primitives in crypto protocols
                # th.narrow(dim, index_start, length)
                keys.append(th.narrow(prim, -1, 0, n_instances))
                if remove:
                    length = prim.shape[-1] - n_instances
                    primitive_stack[i] = th.narrow(prim, -1, n_instances, length)

            return keys
        else:
            raise EmptyCryptoPrimitiveStoreError(self, type_op, available_instances, n_instances) 
开发者ID:OpenMined,项目名称:PySyft,代码行数:34,代码来源:primitives.py

示例6: visit_slice

# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def visit_slice(self, op: Slice, network: PyTorchNetwork):
        self._add_computation(lambda a: a, op.o_output, (op.i_data,))
        for dim, start, end in zip(op.axes.get_value(), op.starts.get_value(), op.ends.get_value()):
            self._add_computation(torch.narrow, op.o_output, (op.o_output,)) 
开发者ID:deep500,项目名称:deep500,代码行数:6,代码来源:pytorch_visitor.py


注:本文中的torch.narrow方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。