本文整理汇总了Python中torch.narrow方法的典型用法代码示例。torch.narrow(input, dim, start, length) 返回 input 在维度 dim 上从下标 start 开始、长度为 length 的切片（一个与原张量共享存储的视图）。如果您正苦于以下问题：Python torch.narrow方法的具体用法是什么？torch.narrow怎么用？有哪些torch.narrow的使用例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的其他用法示例。
在下文中一共展示了torch.narrow方法的6个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def forward(self, x):
    """Classify a flat 280*280*4 pixel buffer and return class probabilities.

    The buffer is viewed as (280, 280, 4) and only channel 3 is kept
    (presumably the alpha channel of an RGBA canvas -- confirm with the caller).
    That single-channel image is average-pooled over 10x10 patches
    (280 / 10 -> 28x28), divided by 255 and normalised with the module-level
    MEAN / STANDARD_DEVIATION before being passed through the network.

    Returns:
        Softmax probabilities of shape (1, out_features of fc2).
    """
    # Keep one channel: narrow(dim=2, start=3, length=1) is x[:, :, 3:4].
    channel = torch.narrow(x.reshape(280, 280, 4), dim=2, start=3, length=1)
    image = channel.reshape(1, 1, 280, 280)     # batch of one, one channel
    image = F.avg_pool2d(image, 10, stride=10)  # 280x280 -> 28x28
    image = (image / 255 - MEAN) / STANDARD_DEVIATION

    # Convolutional feature extractor.
    # NOTE(review): there is no ReLU between conv2 and max-pooling; confirm this
    # matches the architecture the weights were trained with.
    features = F.relu(self.conv1(image))
    features = F.max_pool2d(self.conv2(features), 2)
    features = torch.flatten(self.dropout1(features), 1)

    # Fully connected classifier head.
    hidden = self.dropout2(F.relu(self.fc1(features)))
    logits = self.fc2(hidden)
    return F.softmax(logits, dim=1)
示例2: image_eval
# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def image_eval(batch_size, context, similarity,
               true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy, true_prod_attributes,
               false_prod_texts, false_prod_text_lengths, false_prod_images, false_images_num, false_prod_taxonomies,
               false_prod_attributess):
    """Count, per sample, how many valid negative products outscore the positive one.

    Args:
        batch_size: number of samples in the batch.
        context: context passed unchanged to ``similarity``.
        similarity: callable scoring a context against one product; assumed to
            return a tensor of shape ``(batch_size,)`` -- TODO confirm.
        true_prod_*: features of the positive product, forwarded to ``similarity``.
        false_prod_*: per-negative sequences (indexed ``[i]`` for
            ``i < DatasetOption.num_neg_images``) of negative-product features.
        false_images_num: long tensor with ``batch_size`` entries; sample ``b``
            only has its first ``false_images_num[b]`` negatives treated as valid.

    Returns:
        ``torch.long`` tensor of shape ``(batch_size,)``. Entry ``b`` counts the
        valid negatives whose similarity is strictly greater than the positive's.
    """
    device = GlobalOption.device
    num_neg = DatasetOption.num_neg_images
    true_cos_sim = similarity(context, true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy,
                              true_prod_attributes)

    # Validity mask. Scatter a 1 into column false_images_num[b]; after cumsum,
    # columns >= false_images_num[b] hold 1, so "1 - cumsum" keeps exactly the
    # columns < false_images_num[b]. The extra (num_neg + 1)-th column lets
    # false_images_num[b] == num_neg scatter outside the columns kept below.
    # Allocate directly on the target device instead of creating on the host
    # and copying.
    mask = torch.zeros(batch_size, num_neg + 1, dtype=torch.long, device=device)
    mask.scatter_(1, false_images_num.view(-1, 1), 1)
    mask = 1 - mask.cumsum(dim=1)
    # Drop the sentinel column and put negatives first: (num_neg, batch_size).
    mask = torch.narrow(mask, 1, 0, num_neg).t()

    # Number of valid negatives whose similarity exceeds the positive's.
    gt = torch.zeros(batch_size, dtype=torch.long, device=device)
    for i in range(num_neg):
        false_cos_sim = similarity(context, false_prod_texts[i], false_prod_text_lengths[i], false_prod_images[i],
                                   false_prod_taxonomies[i], false_prod_attributess[i])
        gt += (false_cos_sim > true_cos_sim).long() * mask[i]
    return gt
示例3: image_loss
# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def image_loss(batch_size, context, similarity,
               true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy, true_prod_attributes,
               false_prod_texts, false_prod_text_lengths, false_prod_images, false_images_num, false_prod_taxonomies,
               false_prod_attributess):
    """Margin-1 hinge ranking loss of the positive product against valid negatives.

    For negative slot ``i`` the per-sample loss is
    ``max(0, 1 - sim(positive) + sim(negative_i))``. Only slots with
    ``i < false_images_num[b]`` contribute, and the result is their mean.

    Args:
        batch_size: number of samples in the batch.
        context: context passed unchanged to ``similarity``.
        similarity: callable scoring a context against one product; assumed to
            return a float tensor of shape ``(batch_size,)`` -- TODO confirm.
        true_prod_*: features of the positive product, forwarded to ``similarity``.
        false_prod_*: per-negative sequences (indexed ``[i]`` for
            ``i < DatasetOption.num_neg_images``) of negative-product features.
        false_images_num: long tensor with ``batch_size`` entries; sample ``b``
            only has its first ``false_images_num[b]`` negatives treated as valid.

    Returns:
        0-dim tensor: the mean hinge loss over all valid (negative, sample)
        pairs, or 0 when there is no valid pair. The 0 case previously gave NaN.
    """
    device = GlobalOption.device
    num_neg = DatasetOption.num_neg_images
    # Allocate directly on the target device instead of host + copy.
    ones = torch.ones(batch_size, device=device)
    zeros = torch.zeros(batch_size, device=device)
    true_cos_sim = similarity(context, true_prod_text, true_prod_text_length, true_prod_image, true_prod_taxonomy,
                              true_prod_attributes)

    # Validity mask. After scatter + cumsum, columns >= false_images_num[b] hold 1,
    # so "1 - cumsum" keeps exactly the columns < false_images_num[b]. The extra
    # column absorbs false_images_num[b] == num_neg.
    mask = torch.zeros(batch_size, num_neg + 1, dtype=torch.long, device=device)
    mask.scatter_(1, false_images_num.view(-1, 1), 1)
    mask = 1 - mask.cumsum(dim=1)
    # Drop the sentinel column and put negatives first: (num_neg, batch_size).
    mask = torch.narrow(mask, 1, 0, num_neg).t()

    losses = []
    for i in range(num_neg):
        false_cos_sim = similarity(context, false_prod_texts[i], false_prod_text_lengths[i], false_prod_images[i],
                                   false_prod_taxonomies[i], false_prod_attributess[i])
        losses.append(torch.max(zeros, ones - true_cos_sim + false_cos_sim))
    losses = torch.stack(losses)  # (num_neg, batch_size)

    # masked_select expects a bool mask (uint8 masks are deprecated). Dividing the
    # masked sum by the clamped count equals the mean when at least one pair is
    # valid, and gives 0 instead of mean([]) == NaN when none is.
    valid = mask.bool()
    return losses.masked_select(valid).sum() / valid.sum().clamp(min=1)
示例4: slice_axis
# 需要导入模块: import torch as th（本示例通过别名 th 引用 torch）
# 或者: from torch import narrow [as 别名]
def slice_axis(data, axis, begin, end):
    """Return the elements of ``data`` whose index along ``axis`` lies in ``[begin, end)``.

    Implemented with ``torch.narrow``, so the result is a view sharing storage
    with ``data``. ``narrow`` raises if the requested range does not fit the axis.
    """
    span = end - begin  # torch.narrow takes a length, not an end index
    return th.narrow(data, axis, begin, span)
示例5: get_keys
# 需要导入模块: import torch as th（本示例通过别名 th 引用 torch）
# 或者: from torch import narrow [as 别名]
def get_keys(self, type_op, n_instances=1, remove=True):
    """
    Return FSS keys primitives.

    Args:
        type_op: fss_eq, fss_comp, or xor_add_couple
        n_instances: how many primitives to retrieve. Comparison is pointwise so this is
            convenient: for any matrix of size nxm I can unstack n*m elements for the
            comparison
        remove: if true, pop out the primitive. If false, only read it. Read mode is
            needed because we're working on virtual workers and they need to gather
            at some point and then re-access the keys.

    Returns:
        A list with one tensor per element of the primitive. Each tensor holds the
        first ``n_instances`` entries along its last dimension.

    Raises:
        EmptyCryptoPrimitiveStoreError: if the stack for ``type_op`` is empty or
            holds fewer than ``n_instances`` instances.
    """
    primitive_stack = getattr(self, type_op)
    # Instances are laid out along the LAST dimension (the one narrowed below), so
    # count them there. len() would read the first dimension, which differs for
    # multi-dimensional primitives.
    available_instances = primitive_stack[0].shape[-1] if len(primitive_stack) > 0 else -1
    if available_instances < n_instances:
        raise EmptyCryptoPrimitiveStoreError(self, type_op, available_instances, n_instances)

    keys = []
    # We iterate on the different elements that constitute a given primitive, for
    # example of the beaver triples, you would have 3 elements.
    for i, prim in enumerate(primitive_stack):
        # Select on the last dimension, which is simpler when generating those
        # primitives in crypto protocols. th.narrow(tensor, dim, start, length)
        keys.append(th.narrow(prim, -1, 0, n_instances))
        if remove:
            # Keep only what has not been handed out (updates the stored list in place).
            remaining = prim.shape[-1] - n_instances
            primitive_stack[i] = th.narrow(prim, -1, n_instances, remaining)
    return keys
示例6: visit_slice
# 需要导入模块: import torch [as 别名]
# 或者: from torch import narrow [as 别名]
def visit_slice(self, op: Slice, network: PyTorchNetwork):
    """Lower a ``Slice`` op to a chain of ``torch.narrow`` computations.

    The input ``op.i_data`` is first forwarded unchanged to ``op.o_output``.
    Then, for every ``(dim, start, end)`` triple zipped from ``op.axes``,
    ``op.starts`` and ``op.ends``, a computation narrows ``op.o_output`` along
    ``dim`` to the index range ``[start, end)``.

    Assumes ``_add_computation(fn, output, inputs)`` later calls ``fn`` with the
    values of ``inputs`` (as the identity lambda below implies) -- TODO confirm.
    """
    self._add_computation(lambda a: a, op.o_output, (op.i_data,))
    for dim, start, end in zip(op.axes.get_value(), op.starts.get_value(), op.ends.get_value()):
        # torch.narrow needs dim/start/length, so wrap it. The loop values are bound
        # as defaults; a plain closure would only see the last iteration's values
        # (late binding).
        def _narrow(a, dim=dim, start=start, end=end):
            # Normalise like a Python slice: negative indices count from the end,
            # and out-of-range bounds are clamped to the axis size.
            begin, stop, _ = slice(start, end).indices(a.shape[dim])
            return torch.narrow(a, dim, begin, max(stop - begin, 0))

        self._add_computation(_narrow, op.o_output, (op.o_output,))