本文整理汇总了Python中torch.utils.data.sort方法的典型用法代码示例。如果您正苦于以下问题:Python data.sort方法的具体用法?Python data.sort怎么用?Python data.sort使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch.utils.data
的用法示例。
在下文中一共展示了data.sort方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: collate_text
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_text(data):
    """Collate (caption, cap_bow, idx, cap_id) samples into a padded text batch.

    Returns ``(target, cap_bows, lengths, words_mask), idxs, cap_ids`` where
    ``target`` is a (batch, max_len) long tensor of padded captions and
    ``words_mask`` marks real tokens with 1.0.  All caption outputs are None
    when the batch carries no captions.
    """
    # Sort longest-caption-first only when captions are present.
    if data[0][0] is not None:
        data.sort(key=lambda x: len(x[0]), reverse=True)
    captions, cap_bows, idxs, cap_ids = zip(*data)

    if captions[0] is None:
        target, lengths, words_mask = None, None, None
    else:
        # Pad the tuple of 1-D caption tensors into one 2-D tensor.
        lengths = [len(cap) for cap in captions]
        rows, cols = len(captions), max(lengths)
        target = torch.zeros(rows, cols).long()
        words_mask = torch.zeros(rows, cols)
        for row, cap in enumerate(captions):
            n = lengths[row]
            target[row, :n] = cap[:n]
            words_mask[row, :n] = 1.0

    cap_bows = torch.stack(cap_bows, 0) if cap_bows[0] is not None else None
    return (target, cap_bows, lengths, words_mask), idxs, cap_ids
示例2: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Collate (image, caption, ingredients, img_id, path, pad_value) samples.

    Images and ingredient targets are stacked; captions are padded to the
    longest caption in the batch using ``pad_value`` as the fill token.
    NOTE: sorting by caption length is intentionally disabled here.
    """
    # Sort a data list by caption length (descending order).
    # data.sort(key=lambda x: len(x[2]), reverse=True)
    image_input, captions, ingrs_gt, img_id, path, pad_value = zip(*data)

    # Stack per-sample tensors into batch tensors (3D -> 4D for images).
    image_input = torch.stack(image_input, 0)
    ingrs_gt = torch.stack(ingrs_gt, 0)

    # Pad captions into one 2-D tensor, filling with the pad token.
    lengths = [len(cap) for cap in captions]
    targets = torch.ones(len(captions), max(lengths)).long() * pad_value[0]
    for row, (cap, n) in enumerate(zip(captions, lengths)):
        targets[row, :n] = cap[:n]

    return image_input, targets, ingrs_gt, img_id, path
示例3: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad source/target (and extended-vocabulary) sequences for a seq2seq batch.

    Sorts by source length (descending) so downstream code can use
    pack_padded_sequence; returns four padded 2-D long tensors plus the
    untouched OOV word lists.
    """
    def _pad(batch):
        # Pad a tuple of 1-D tensors into a zero-filled (batch, max_len) tensor.
        sizes = [len(s) for s in batch]
        out = torch.zeros(len(batch), max(sizes)).long()
        for row, s in enumerate(batch):
            out[row, :sizes[row]] = s[:sizes[row]]
        return out

    data.sort(key=lambda x: len(x[0]), reverse=True)
    src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, oov_lst = zip(*data)
    return (_pad(src_seqs), _pad(ext_src_seqs),
            _pad(trg_seqs), _pad(ext_trg_seqs), oov_lst)
示例4: collate_fn_tag
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn_tag(data):
    """Pad seq2seq batches that also carry a per-token tag sequence.

    Same contract as the untagged collate function, plus a padded ``tag_seqs``
    tensor that must be token-aligned with the padded source.
    """
    def _pad(batch):
        # Zero-pad a tuple of 1-D tensors to a (batch, max_len) long tensor.
        sizes = [len(s) for s in batch]
        out = torch.zeros(len(batch), max(sizes)).long()
        for row, s in enumerate(batch):
            out[row, :sizes[row]] = s[:sizes[row]]
        return out

    data.sort(key=lambda x: len(x[0]), reverse=True)
    src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, oov_lst, tag_seqs = zip(
        *data)

    src_seqs = _pad(src_seqs)
    ext_src_seqs = _pad(ext_src_seqs)
    trg_seqs = _pad(trg_seqs)
    ext_trg_seqs = _pad(ext_trg_seqs)
    tag_seqs = _pad(tag_seqs)

    # Tags must stay aligned one-to-one with source tokens after padding.
    assert src_seqs.size(1) == tag_seqs.size(
        1), "length of tokens and tags should be equal"
    return src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, tag_seqs, oov_lst
示例5: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Collate a batch of dict samples with keys 'X' (sequence) and 'y' (label).

    Pads the 'X' sequences (padding index 0) directly on the GPU and returns
    ``(x_batch, y_batch)`` as CUDA tensors.  Requires a CUDA device.
    """
    def _pad(seqs):
        # padding index 0
        sizes = [len(s) for s in seqs]
        out = torch.zeros(len(seqs), max(sizes)).long().cuda()
        for row, s in enumerate(seqs):
            out[row, :sizes[row]] = s[:sizes[row]]
        return out, sizes

    # Longest source first (pack_padded_sequence-friendly ordering).
    data.sort(key=lambda x: len(x["X"]), reverse=True)

    # Transpose the list of dicts into a dict of lists.
    item_info = {key: [d[key] for d in data] for key in data[0].keys()}

    x_batch, _ = _pad(item_info['X'])
    return x_batch, torch.tensor(item_info['y'], device='cuda', dtype=torch.long)
示例6: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Build mini-batch tensors from a list of (image, caption, id, img_id) tuples.

    Args:
        data: list of (image, caption, id, img_id) tuples.
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.
    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
        ids: tuple of the per-sample ids, in sorted batch order.
    """
    # Longest caption first so pack_padded_sequence can be applied downstream.
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions, ids, img_ids = list(zip(*data))

    # Stack images: tuple of 3-D tensors -> one 4-D tensor.
    images = torch.stack(images, 0)

    # Zero-pad captions: tuple of 1-D tensors -> one 2-D long tensor.
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (cap, n) in enumerate(zip(captions, lengths)):
        targets[row, :n] = cap[:n]
    return images, targets, lengths, ids
示例7: collate_fn_train_text
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn_train_text(data):
    """Build mini-batch tensors from (image, caption, id, img_id, extended_captions).

    Each sample carries one primary caption plus a list of extended captions;
    every group of (1 + len(extended)) captions shares the primary caption's
    length entry, and extended captions are truncated/padded to that length.

    Returns:
        images: stacked image tensor.
        targets: (batch * group, padded_length) long tensor of padded captions.
        lengths: list with one entry per target row.
        ids: tuple of sample ids in sorted batch order.
    """
    # Longest primary caption first.
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions, ids, img_ids, extended_captions = list(zip(*data))

    # Stack images: tuple of 3-D tensors -> one 4-D tensor.
    images = torch.stack(images, 0)

    # Each sample contributes the primary caption plus its extended captions.
    group = len(extended_captions[0]) + 1
    lengths = []
    for cap in captions:
        lengths += [len(cap)] * group

    targets = torch.zeros(len(captions) * group, max(lengths)).long()
    for i, cap in enumerate(captions):
        base = i * group
        n = lengths[base]
        targets[base, :n] = cap[:n]
        # Extended captions reuse the primary caption's length window.
        for offset, ext in enumerate(extended_captions[i]):
            targets[base + offset + 1, :n] = ext[:n]
    return images, targets, lengths, ids
示例8: collate_fn_test_text
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn_test_text(data):
    """Collate text-only test samples (caption, id) into a padded batch.

    Returns ``(None, targets, lengths, ids)`` — the leading None keeps the
    return shape compatible with the image-bearing collate functions.
    """
    # Longest caption first.
    data.sort(key=lambda x: len(x[0]), reverse=True)
    captions, ids = list(zip(*data))

    # Zero-pad captions into a (batch, max_len) long tensor.
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (cap, n) in enumerate(zip(captions, lengths)):
        targets[row, :n] = cap[:n]
    return None, targets, lengths, ids
示例9: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad and batch (src, trg, ind) sequences, returning time-major tensors.

    Sorts by the last tuple element's length (descending) so callers can use
    pack_padded_sequence.  Pads with 1 — presumably the project's PAD token
    index (defined elsewhere; confirm against the vocabulary).  Moves tensors
    to GPU when the module-level USE_CUDA flag is set.
    """
    def _pad(seqs):
        sizes = [len(s) for s in seqs]
        out = torch.ones(len(seqs), max(sizes)).long()
        for row, s in enumerate(seqs):
            out[row, :sizes[row]] = s[:sizes[row]]
        return out, sizes

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[-1]), reverse=True)
    src_seqs, trg_seqs, ind_seqs, max_len, src_plain, trg_plain = zip(*data)

    # Pad each field: tuple of 1-D tensors -> one 2-D tensor each.
    src_seqs, src_lengths = _pad(src_seqs)
    trg_seqs, trg_lengths = _pad(trg_seqs)
    ind_seqs, _ = _pad(ind_seqs)

    # (batch, time) -> (time, batch), wrapped in (legacy) autograd Variables.
    src_seqs = Variable(src_seqs).transpose(0, 1)
    trg_seqs = Variable(trg_seqs).transpose(0, 1)
    ind_seqs = Variable(ind_seqs).transpose(0, 1)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        ind_seqs = ind_seqs.cuda()
    return src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, src_plain, trg_plain
示例10: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad and batch memory-network samples into time-major tensors.

    Fields padded with ``max_len`` truthy get a third MEM_TOKEN_SIZE axis
    (memory tuples); the rest are plain token sequences.  Pads with 1 —
    presumably the PAD token index (defined elsewhere).  Moves everything to
    GPU when the module-level USE_CUDA flag is set.
    """
    def _pad(seqs, as_memory):
        sizes = [len(s) for s in seqs]
        if as_memory:
            # (batch, max_len, MEM_TOKEN_SIZE) — each step is a memory tuple.
            out = torch.ones(len(seqs), max(sizes), MEM_TOKEN_SIZE).long()
            for row, s in enumerate(seqs):
                out[row, :sizes[row], :] = s[:sizes[row]]
        else:
            out = torch.ones(len(seqs), max(sizes)).long()
            for row, s in enumerate(seqs):
                out[row, :sizes[row]] = s[:sizes[row]]
        return out, sizes

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[-1]), reverse=True)
    src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain, entity, conv_seq, kb_arr = zip(*data)

    src_seqs, src_lengths = _pad(src_seqs, max_len)
    trg_seqs, trg_lengths = _pad(trg_seqs, None)
    ind_seqs, _ = _pad(ind_seqs, None)
    gete_s, _ = _pad(gete_s, None)
    conv_seqs, conv_lengths = _pad(conv_seq, max_len)

    # (batch, time) -> (time, batch), wrapped in (legacy) autograd Variables.
    src_seqs = Variable(src_seqs).transpose(0, 1)
    trg_seqs = Variable(trg_seqs).transpose(0, 1)
    ind_seqs = Variable(ind_seqs).transpose(0, 1)
    gete_s = Variable(gete_s).transpose(0, 1)
    conv_seqs = Variable(conv_seqs).transpose(0, 1)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        ind_seqs = ind_seqs.cuda()
        gete_s = gete_s.cuda()
        conv_seqs = conv_seqs.cuda()
    return src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s, src_plain, trg_plain, entity, conv_seqs, conv_lengths, kb_arr
示例11: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad and batch memory-network samples with per-domain entity lists.

    Identical padding scheme to the sibling collate functions: memory fields
    (``max_len`` truthy) get a MEM_TOKEN_SIZE third axis, the rest are plain
    sequences padded with 1 (presumably the PAD index — defined elsewhere).
    Also passes through calendar/navigation/weather entity lists untouched.
    """
    def _pad(seqs, as_memory):
        sizes = [len(s) for s in seqs]
        if as_memory:
            # (batch, max_len, MEM_TOKEN_SIZE) — each step is a memory tuple.
            out = torch.ones(len(seqs), max(sizes), MEM_TOKEN_SIZE).long()
            for row, s in enumerate(seqs):
                out[row, :sizes[row], :] = s[:sizes[row]]
        else:
            out = torch.ones(len(seqs), max(sizes)).long()
            for row, s in enumerate(seqs):
                out[row, :sizes[row]] = s[:sizes[row]]
        return out, sizes

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[-1]), reverse=True)
    src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet, conv_seq, kb_arr = zip(*data)

    src_seqs, src_lengths = _pad(src_seqs, max_len)
    trg_seqs, trg_lengths = _pad(trg_seqs, None)
    ind_seqs, _ = _pad(ind_seqs, None)
    gete_s, _ = _pad(gete_s, None)
    conv_seqs, conv_lengths = _pad(conv_seq, max_len)

    # (batch, time) -> (time, batch), wrapped in (legacy) autograd Variables.
    src_seqs = Variable(src_seqs).transpose(0, 1)
    trg_seqs = Variable(trg_seqs).transpose(0, 1)
    ind_seqs = Variable(ind_seqs).transpose(0, 1)
    gete_s = Variable(gete_s).transpose(0, 1)
    conv_seqs = Variable(conv_seqs).transpose(0, 1)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        ind_seqs = ind_seqs.cuda()
        gete_s = gete_s.cuda()
        conv_seqs = conv_seqs.cuda()
    return src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s, src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet, conv_seqs, conv_lengths, kb_arr
示例12: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad and batch memory-network samples carrying entity/ID/KB metadata.

    Sorts by source length (x[0], descending) for pack_padded_sequence.
    Memory fields (``max_len`` truthy) get a MEM_TOKEN_SIZE third axis; the
    rest are plain sequences padded with 1 (presumably the PAD index, defined
    elsewhere).  ``ent``, ``ID`` and ``kb_arr`` pass through untouched.
    """
    def _pad(seqs, as_memory):
        sizes = [len(s) for s in seqs]
        if as_memory:
            # (batch, max_len, MEM_TOKEN_SIZE) — each step is a memory tuple.
            out = torch.ones(len(seqs), max(sizes), MEM_TOKEN_SIZE).long()
            for row, s in enumerate(seqs):
                out[row, :sizes[row], :] = s[:sizes[row]]
        else:
            out = torch.ones(len(seqs), max(sizes)).long()
            for row, s in enumerate(seqs):
                out[row, :sizes[row]] = s[:sizes[row]]
        return out, sizes

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[0]), reverse=True)
    src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain, conv_seq, ent, ID, kb_arr = zip(*data)

    src_seqs, src_lengths = _pad(src_seqs, max_len)
    trg_seqs, trg_lengths = _pad(trg_seqs, None)
    ind_seqs, _ = _pad(ind_seqs, None)
    gete_s, _ = _pad(gete_s, None)
    conv_seqs, conv_lengths = _pad(conv_seq, max_len)

    # (batch, time) -> (time, batch), wrapped in (legacy) autograd Variables.
    src_seqs = Variable(src_seqs).transpose(0, 1)
    trg_seqs = Variable(trg_seqs).transpose(0, 1)
    ind_seqs = Variable(ind_seqs).transpose(0, 1)
    gete_s = Variable(gete_s).transpose(0, 1)
    conv_seqs = Variable(conv_seqs).transpose(0, 1)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        ind_seqs = ind_seqs.cuda()
        gete_s = gete_s.cuda()
        conv_seqs = conv_seqs.cuda()
    return src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s, src_plain, trg_plain, conv_seqs, conv_lengths, ent, ID, kb_arr
示例13: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad and batch (src, trg, ind, gate) sequences plus entity metadata.

    Sorts by source length (descending) so callers can use
    pack_padded_sequence, zero-pads each sequence field, transposes to
    time-major, wraps in (legacy) autograd Variables, and moves to GPU when
    the module-level USE_CUDA flag is set.  Entity lists pass through
    untouched.

    Fix: the original ``merge`` had an ``if (max_len): ... else: ...`` whose
    two branches allocated the exact same ``torch.zeros(...).long()`` tensor
    — a dead conditional (likely a copy-paste remnant from the sibling
    collate functions that branch on a memory axis).  Collapsed it; behavior
    is unchanged.
    """
    def merge(sequences, max_len):
        # max_len is accepted for signature-compatibility with the sibling
        # collate functions but no longer affects the allocation.
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[0]), reverse=True)
    # separate source and target sequences
    src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet = zip(*data)
    # merge sequences (from tuple of 1D tensor to 2D tensor)
    src_seqs, src_lengths = merge(src_seqs, max_len)
    trg_seqs, trg_lengths = merge(trg_seqs, None)
    ind_seqs, _ = merge(ind_seqs, None)
    gete_s, _ = merge(gete_s, None)
    # (batch, time) -> (time, batch)
    src_seqs = Variable(src_seqs).transpose(0, 1)
    trg_seqs = Variable(trg_seqs).transpose(0, 1)
    ind_seqs = Variable(ind_seqs).transpose(0, 1)
    gete_s = Variable(gete_s).transpose(0, 1)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        ind_seqs = ind_seqs.cuda()
        gete_s = gete_s.cuda()
    return src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s, src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet
示例14: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad and batch (src, trg, ind, gate) sequences into time-major tensors.

    When ``max_len`` is truthy the source is padded to ``max_len[0]`` columns
    (a batch-wide fixed width supplied by the dataset); otherwise fields are
    padded to the batch maximum.  Pads with 1 — presumably the PAD token
    index (defined elsewhere).  Moves tensors to GPU when USE_CUDA is set.
    """
    def _pad(seqs, fixed):
        sizes = [len(s) for s in seqs]
        width = fixed[0] if fixed else max(sizes)
        out = torch.ones(len(seqs), width).long()
        for row, s in enumerate(seqs):
            out[row, :sizes[row]] = s[:sizes[row]]
        return out, sizes

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[0]), reverse=True)
    src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain = zip(*data)

    src_seqs, src_lengths = _pad(src_seqs, max_len)
    trg_seqs, trg_lengths = _pad(trg_seqs, None)
    ind_seqs, _ = _pad(ind_seqs, None)
    gete_s, _ = _pad(gete_s, None)

    # (batch, time) -> (time, batch), wrapped in (legacy) autograd Variables.
    src_seqs = Variable(src_seqs).transpose(0, 1)
    trg_seqs = Variable(trg_seqs).transpose(0, 1)
    ind_seqs = Variable(ind_seqs).transpose(0, 1)
    gete_s = Variable(gete_s).transpose(0, 1)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        ind_seqs = ind_seqs.cuda()
        gete_s = gete_s.cuda()
    return src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s, src_plain, trg_plain
示例15: collate_fn
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Create mini-batch tensors from recipe samples.

    A custom collate_fn is needed because the default collate cannot pad
    variable-length ingredient captions.

    Args:
        data: list of tuples
            (img, instrs, itr_ln, ingrs, igr_ln, ingr_cap, class_label,
             ret, one_hot_vec, food_id), sorted here by caption length.
    Returns:
        Two lists: model inputs
            [images, instrs, itr_ln, ingrs, igr_ln, food_ids]
        and supervision targets
            [images, instrs, itr_ln, ingrs, igr_ln, targets, lengths,
             class_label, ret, one_hot_vec].
    """
    # Longest ingredient caption first.
    data.sort(key=lambda x: len(x[5]), reverse=True)
    (img, instrs, itr_ln, ingrs, igr_ln,
     ingr_cap, class_label, ret, one_hot_vec, food_id) = zip(*data)

    # Stack fixed-size tensor fields; lengths/labels become LongTensors.
    images = torch.stack(img, 0)
    instrs = torch.stack(instrs, 0)
    itr_ln = torch.LongTensor(list(itr_ln))
    ingrs = torch.stack(ingrs, 0)
    igr_ln = torch.LongTensor(list(igr_ln))
    class_label = torch.LongTensor(list(class_label))
    ret = torch.stack(ret, 0)

    # Zero-pad the variable-length ingredient captions.
    lengths = [len(cap) for cap in ingr_cap]
    targets = torch.zeros(len(ingr_cap), max(lengths)).long()
    for row, (cap, n) in enumerate(zip(ingr_cap, lengths)):
        targets[row, :n] = cap[:n]

    one_hot_vec = torch.stack(one_hot_vec, 0)
    return [images, instrs, itr_ln, ingrs, igr_ln, list(food_id)], \
        [images, instrs, itr_ln, ingrs, igr_ln, targets, lengths, class_label, ret, one_hot_vec]