当前位置: 首页>>代码示例>>Python>>正文


Python data.sort方法代码示例

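本文整理汇总了 Python 中 data.sort 方法的典型用法代码示例。如果您正苦于以下问题:Python data.sort 方法的具体用法?Python data.sort 怎么用?Python data.sort 使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解 torch.utils.data 的用法示例。注意:下列示例中的 data 是 DataLoader 传给 collate_fn 的批次列表(Python 内置 list),data.sort 即 list.sort 的原地排序,并不是 torch.utils.data 模块提供的函数。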


在下文中一共展示了data.sort方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: collate_text

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_text(data):
    """Collate (caption, cap_bow, idx, cap_id) samples into one text batch.

    If the first sample carries a caption, ``data`` is sorted in place by
    caption length, longest first, before the fields are split apart.

    Returns:
        ``((target, cap_bows, lengths, words_mask), idxs, cap_ids)`` where
        ``target`` is a zero-padded long tensor of captions, ``words_mask``
        is a float tensor holding 1.0 at every real token position, and
        ``lengths`` lists the caption lengths.  All three are ``None`` when
        the first caption is ``None``; ``cap_bows`` is ``None`` when the
        first bag-of-words entry is ``None``, otherwise the stacked tensor.
    """
    if data[0][0] is not None:
        data.sort(key=lambda sample: len(sample[0]), reverse=True)
    captions, cap_bows, idxs, cap_ids = zip(*data)

    if captions[0] is None:
        target = lengths = words_mask = None
    else:
        # Pad every caption to the longest one (tuple of 1D -> one 2D tensor).
        lengths = list(map(len, captions))
        n_rows, n_cols = len(captions), max(lengths)
        target = torch.zeros(n_rows, n_cols).long()
        words_mask = torch.zeros(n_rows, n_cols)
        for row, (caption, n_words) in enumerate(zip(captions, lengths)):
            target[row, :n_words] = caption[:n_words]
            words_mask[row, :n_words] = 1.0

    if cap_bows[0] is not None:
        cap_bows = torch.stack(cap_bows, 0)
    else:
        cap_bows = None

    return (target, cap_bows, lengths, words_mask), idxs, cap_ids
开发者ID:danieljf24,项目名称:dual_encoding,代码行数:27,代码来源:data_provider.py

示例2: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch (image, caption, ingrs_gt, img_id, path, pad_value) samples.

    The incoming sample order is kept as-is (no length sorting is done).

    Returns:
        ``(image_input, targets, ingrs_gt, img_id, path)``: the images and
        ingredient targets stacked along a new leading dimension, the
        captions right-padded with the first sample's ``pad_value`` up to
        the longest caption, and the ids / paths as tuples.
    """
    image_input, captions, ingrs_gt, img_id, path, pad_value = zip(*data)

    # Tuples of per-sample tensors -> single batched tensors.
    image_input = torch.stack(image_input, 0)
    ingrs_gt = torch.stack(ingrs_gt, 0)

    # Start from a tensor filled with the pad value, then copy captions in.
    lengths = [len(caption) for caption in captions]
    filler = torch.ones(len(captions), max(lengths)).long()
    targets = filler * pad_value[0]
    for row, (caption, n_tokens) in enumerate(zip(captions, lengths)):
        targets[row, :n_tokens] = caption[:n_tokens]

    return image_input, targets, ingrs_gt, img_id, path
开发者ID:facebookresearch,项目名称:inversecooking,代码行数:22,代码来源:data_loader.py

示例3: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch (src, ext_src, trg, ext_trg, oov_list) samples.

    ``data`` is sorted in place by the length of each sample's first field,
    longest first.  The four sequence fields are each zero-padded to their
    own longest entry and returned as long tensors; the out-of-vocabulary
    lists are returned untouched as a tuple.
    """
    def pad_batch(seqs):
        # Right-pad with 0 up to the longest sequence in this column.
        sizes = [len(seq) for seq in seqs]
        batch = torch.zeros(len(seqs), max(sizes)).long()
        for row, (seq, size) in enumerate(zip(seqs, sizes)):
            batch[row, :size] = seq[:size]
        return batch

    data.sort(key=lambda sample: len(sample[0]), reverse=True)
    src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, oov_lst = zip(*data)

    columns = (src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs)
    padded = [pad_batch(column) for column in columns]
    return (*padded, oov_lst)
开发者ID:seanie12,项目名称:neural-question-generation,代码行数:20,代码来源:data_utils.py

示例4: collate_fn_tag

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn_tag(data):
    """Batch (src, ext_src, trg, ext_trg, oov_list, tags) samples.

    ``data`` is sorted in place by the length of each sample's first field,
    longest first.  The five sequence fields are each zero-padded to their
    own longest entry.  An ``AssertionError`` is raised when the padded
    source and tag batches end up with different widths.

    Returns:
        ``(src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, tag_seqs, oov_lst)``
    """
    def pad_batch(seqs):
        # Right-pad with 0 up to the longest sequence in this column.
        sizes = [len(seq) for seq in seqs]
        batch = torch.zeros(len(seqs), max(sizes)).long()
        for row, (seq, size) in enumerate(zip(seqs, sizes)):
            batch[row, :size] = seq[:size]
        return batch

    data.sort(key=lambda sample: len(sample[0]), reverse=True)
    (src_seqs, ext_src_seqs, trg_seqs,
     ext_trg_seqs, oov_lst, tag_seqs) = zip(*data)

    src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, tag_seqs = [
        pad_batch(column)
        for column in (src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, tag_seqs)
    ]

    same_width = src_seqs.size(1) == tag_seqs.size(1)
    assert same_width, "length of tokens and tags should be equal"

    return src_seqs, ext_src_seqs, trg_seqs, ext_trg_seqs, tag_seqs, oov_lst
开发者ID:seanie12,项目名称:neural-question-generation,代码行数:25,代码来源:data_utils.py

示例5: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Pad and batch a list of ``{"X": sequence, "y": label}`` samples.

    ``data`` is sorted in place by ``len(sample["X"])``, longest first.

    The batch is placed on the CUDA device when one is available and on the
    CPU otherwise, so the function no longer raises on CPU-only hosts.

    Args:
        data: list of dicts; each must provide an ``"X"`` sequence and a
            ``"y"`` label.

    Returns:
        ``(x_batch, y_batch)``: ``x_batch`` is a long tensor with one row
        per sample, right-padded with 0 to the longest ``"X"``; ``y_batch``
        is a long tensor of the ``"y"`` values in the same (sorted) order.
    """
    # Use the GPU when present, fall back to the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        # padding index 0
        padded_seqs = torch.zeros(
            len(sequences), max(lengths), dtype=torch.long, device=device)
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs

    data.sort(key=lambda x: len(x["X"]), reverse=True)  # sort by source seq

    # Only the "X" and "y" fields are consumed.
    x_batch = merge([sample["X"] for sample in data])
    y_batch = [sample["y"] for sample in data]

    return x_batch, torch.tensor(y_batch, device=device, dtype=torch.long)
开发者ID:uber-research,项目名称:PPLM,代码行数:23,代码来源:gpt2tunediscrim.py

示例6: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Collate (image, caption, id, img_id) samples into a mini-batch.

    ``data`` is sorted in place by caption length, longest first.

    Args:
        data: list of 4-field samples.
            - image: tensor; all images must share one shape so they can
              be stacked (presumably (3, 256, 256) -- confirm with the
              dataset).
            - caption: 1D tensor of variable length.
            - id, img_id: passed through / ignored respectively.

    Returns:
        images: the images stacked along a new leading batch dimension.
        targets: long tensor (batch_size, longest_caption), zero-padded.
        lengths: list with the true length of every caption.
        ids: tuple of the per-sample ids, in the sorted order.
    """
    data.sort(key=lambda sample: len(sample[1]), reverse=True)
    images, captions, ids, _img_ids = list(zip(*data))

    # Tuple of per-sample image tensors -> one batched tensor.
    images = torch.stack(images, 0)

    # Tuple of 1D captions -> one zero-padded 2D tensor.
    lengths = list(map(len, captions))
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (caption, n_tokens) in enumerate(zip(captions, lengths)):
        targets[row, :n_tokens] = caption[:n_tokens]

    return images, targets, lengths, ids
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:29,代码来源:data.py

示例7: collate_fn_train_text

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn_train_text(data):
    """Collate (image, caption, id, img_id, extended_captions) samples.

    ``data`` is sorted in place by caption length, longest first.  Every
    caption is followed in ``targets`` by its extended captions, so each
    sample contributes ``len(extended_captions[0]) + 1`` consecutive rows.
    Extended captions are copied only up to the length of their original
    caption.

    Returns:
        images: the images stacked along a new leading batch dimension.
        targets: zero-padded long tensor with one row per caption and per
            extended caption.
        lengths: list with one entry per ``targets`` row; every row of a
            group carries the length of that group's original caption.
        ids: tuple of the per-sample ids, in the sorted order.
    """
    data.sort(key=lambda sample: len(sample[1]), reverse=True)
    images, captions, ids, _img_ids, extended_captions = list(zip(*data))

    # Tuple of per-sample image tensors -> one batched tensor.
    images = torch.stack(images, 0)

    # Rows per sample: the caption itself plus its extended captions.
    group = len(extended_captions[0]) + 1
    lengths = [len(caption) for caption in captions for _ in range(group)]

    targets = torch.zeros(len(captions) * group, max(lengths)).long()
    for idx, (caption, extras) in enumerate(zip(captions, extended_captions)):
        n_tokens = len(caption)
        first_row = idx * group
        targets[first_row, :n_tokens] = caption[:n_tokens]
        for offset, extra in enumerate(extras, start=1):
            targets[first_row + offset, :n_tokens] = extra[:n_tokens]
    return images, targets, lengths, ids
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:33,代码来源:data.py

示例8: collate_fn_test_text

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn_test_text(data):
    """Collate (caption, id) samples into a text-only mini-batch.

    ``data`` is sorted in place by caption length, longest first.

    Returns:
        ``(None, targets, lengths, ids)``: the leading ``None`` fills the
        slot where an image batch would go, ``targets`` is the zero-padded
        long tensor of captions, ``lengths`` the true caption lengths and
        ``ids`` the per-sample ids in the sorted order.
    """
    data.sort(key=lambda sample: len(sample[0]), reverse=True)
    captions, ids = list(zip(*data))

    # Tuple of 1D captions -> one zero-padded 2D tensor.
    lengths = list(map(len, captions))
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (caption, n_tokens) in enumerate(zip(captions, lengths)):
        targets[row, :n_tokens] = caption[:n_tokens]

    return None, targets, lengths, ids
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:15,代码来源:data.py

示例9: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch (src, trg, ind, max_len, src_plain, trg_plain) samples.

    ``data`` is sorted in place by the length of each sample's last field,
    longest first.  The three sequence fields are padded with 1 to their
    own longest entry, wrapped in ``Variable`` and transposed so the batch
    dimension comes second; they are moved to the GPU when ``USE_CUDA`` is
    set.  The per-sample ``max_len`` values are unpacked but not used.

    Returns:
        ``(src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs,
        src_plain, trg_plain)``
    """
    def pad_batch(seqs):
        # The batch starts as all ones, so 1 is the padding value.
        sizes = [len(seq) for seq in seqs]
        batch = torch.ones(len(seqs), max(sizes)).long()
        for row, (seq, size) in enumerate(zip(seqs, sizes)):
            batch[row, :size] = seq[:size]
        return batch, sizes

    # Descending-length order, intended for pack_padded_sequence.
    # NOTE(review): the key is the LAST field (trg_plain) -- confirm this
    # matches the lengths that are later handed to pack_padded_sequence.
    data.sort(key=lambda sample: len(sample[-1]), reverse=True)
    src_seqs, trg_seqs, ind_seqs, max_len, src_plain, trg_plain = zip(*data)

    # Tuples of 1D tensors -> padded 2D tensors.
    src_seqs, src_lengths = pad_batch(src_seqs)
    trg_seqs, trg_lengths = pad_batch(trg_seqs)
    ind_seqs, _ = pad_batch(ind_seqs)

    # (batch, time) -> (time, batch)
    src_seqs, trg_seqs, ind_seqs = [
        Variable(batch).transpose(0, 1)
        for batch in (src_seqs, trg_seqs, ind_seqs)
    ]

    if USE_CUDA:
        src_seqs, trg_seqs, ind_seqs = [
            batch.cuda() for batch in (src_seqs, trg_seqs, ind_seqs)
        ]
    return (src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs,
            src_plain, trg_plain)
开发者ID:ConvLab,项目名称:ConvLab,代码行数:29,代码来源:utils_NMT.py

示例10: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch 10-field Mem2Seq samples.

    Each sample unpacks to (src, trg, ind, gete, max_len, src_plain,
    trg_plain, entity, conv_seq, kb_arr).  ``data`` is sorted in place by
    the length of each sample's last field, longest first.  ``src`` and
    ``conv_seq`` are padded as 3-D batches whose last dimension is
    ``MEM_TOKEN_SIZE``; ``trg``, ``ind`` and ``gete`` as 2-D batches.  The
    padding value is 1.  All five batches are wrapped in ``Variable``,
    transposed so the batch dimension comes second, and moved to the GPU
    when ``USE_CUDA`` is set.

    Returns:
        ``(src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
        src_plain, trg_plain, entity, conv_seqs, conv_lengths, kb_arr)``
    """
    def merge(sequences, max_len):
        lengths = [len(seq) for seq in sequences]
        n_rows, longest = len(sequences), max(lengths)
        # A truthy max_len selects the 3-D (memory-token) layout.
        if max_len:
            shape = (n_rows, longest, MEM_TOKEN_SIZE)
        else:
            shape = (n_rows, longest)
        padded_seqs = torch.ones(*shape).long()
        for row, (seq, end) in enumerate(zip(sequences, lengths)):
            padded_seqs[row, :end] = seq[:end]
        return padded_seqs, lengths

    # Descending-length order, intended for pack_padded_sequence.
    data.sort(key=lambda sample: len(sample[-1]), reverse=True)
    (src_seqs, trg_seqs, ind_seqs, gete_s, max_len,
     src_plain, trg_plain, entity, conv_seq, kb_arr) = zip(*data)

    # Tuples of per-sample tensors -> padded batch tensors.
    src_seqs, src_lengths = merge(src_seqs, max_len)
    trg_seqs, trg_lengths = merge(trg_seqs, None)
    ind_seqs, _ = merge(ind_seqs, None)
    gete_s, _ = merge(gete_s, None)
    conv_seqs, conv_lengths = merge(conv_seq, max_len)

    # Swap the first two dimensions: batch-major -> time-major.
    src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs = [
        Variable(batch).transpose(0, 1)
        for batch in (src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs)
    ]

    if USE_CUDA:
        src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs = [
            batch.cuda()
            for batch in (src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs)
        ]
    return (src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
            src_plain, trg_plain, entity, conv_seqs, conv_lengths, kb_arr)
开发者ID:ConvLab,项目名称:ConvLab,代码行数:41,代码来源:utils_woz_mem2seq.py

示例11: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch 13-field Mem2Seq samples.

    Each sample unpacks to (src, trg, ind, gete, max_len, src_plain,
    trg_plain, entity, entity_cal, entity_nav, entity_wet, conv_seq,
    kb_arr).  ``data`` is sorted in place by the length of each sample's
    last field, longest first.  ``src`` and ``conv_seq`` are padded as 3-D
    batches whose last dimension is ``MEM_TOKEN_SIZE``; ``trg``, ``ind``
    and ``gete`` as 2-D batches.  The padding value is 1.  All five batches
    are wrapped in ``Variable``, transposed so the batch dimension comes
    second, and moved to the GPU when ``USE_CUDA`` is set.

    Returns:
        ``(src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
        src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet,
        conv_seqs, conv_lengths, kb_arr)``
    """
    def merge(sequences, max_len):
        lengths = [len(seq) for seq in sequences]
        n_rows, longest = len(sequences), max(lengths)
        # A truthy max_len selects the 3-D (memory-token) layout.
        if max_len:
            shape = (n_rows, longest, MEM_TOKEN_SIZE)
        else:
            shape = (n_rows, longest)
        padded_seqs = torch.ones(*shape).long()
        for row, (seq, end) in enumerate(zip(sequences, lengths)):
            padded_seqs[row, :end] = seq[:end]
        return padded_seqs, lengths

    # Descending-length order, intended for pack_padded_sequence.
    data.sort(key=lambda sample: len(sample[-1]), reverse=True)
    (src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain,
     entity, entity_cal, entity_nav, entity_wet, conv_seq, kb_arr) = zip(*data)

    # Tuples of per-sample tensors -> padded batch tensors.
    src_seqs, src_lengths = merge(src_seqs, max_len)
    trg_seqs, trg_lengths = merge(trg_seqs, None)
    ind_seqs, _ = merge(ind_seqs, None)
    gete_s, _ = merge(gete_s, None)
    conv_seqs, conv_lengths = merge(conv_seq, max_len)

    # Swap the first two dimensions: batch-major -> time-major.
    src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs = [
        Variable(batch).transpose(0, 1)
        for batch in (src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs)
    ]

    if USE_CUDA:
        src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs = [
            batch.cuda()
            for batch in (src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs)
        ]
    return (src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
            src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet,
            conv_seqs, conv_lengths, kb_arr)
开发者ID:ConvLab,项目名称:ConvLab,代码行数:41,代码来源:utils_kvr_mem2seq.py

示例12: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch 11-field Mem2Seq samples.

    Each sample unpacks to (src, trg, ind, gete, max_len, src_plain,
    trg_plain, conv_seq, ent, ID, kb_arr).  ``data`` is sorted in place by
    the length of each sample's first field, longest first.  ``src`` and
    ``conv_seq`` are padded as 3-D batches whose last dimension is
    ``MEM_TOKEN_SIZE``; ``trg``, ``ind`` and ``gete`` as 2-D batches.  The
    padding value is 1.  All five batches are wrapped in ``Variable``,
    transposed so the batch dimension comes second, and moved to the GPU
    when ``USE_CUDA`` is set.

    Returns:
        ``(src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
        src_plain, trg_plain, conv_seqs, conv_lengths, ent, ID, kb_arr)``
    """
    def merge(sequences, max_len):
        lengths = [len(seq) for seq in sequences]
        n_rows, longest = len(sequences), max(lengths)
        # A truthy max_len selects the 3-D (memory-token) layout.
        if max_len:
            shape = (n_rows, longest, MEM_TOKEN_SIZE)
        else:
            shape = (n_rows, longest)
        padded_seqs = torch.ones(*shape).long()
        for row, (seq, end) in enumerate(zip(sequences, lengths)):
            padded_seqs[row, :end] = seq[:end]
        return padded_seqs, lengths

    # Descending source-length order, intended for pack_padded_sequence.
    data.sort(key=lambda sample: len(sample[0]), reverse=True)
    (src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain,
     conv_seq, ent, ID, kb_arr) = zip(*data)

    # Tuples of per-sample tensors -> padded batch tensors.
    src_seqs, src_lengths = merge(src_seqs, max_len)
    trg_seqs, trg_lengths = merge(trg_seqs, None)
    ind_seqs, _ = merge(ind_seqs, None)
    gete_s, _ = merge(gete_s, None)
    conv_seqs, conv_lengths = merge(conv_seq, max_len)

    # Swap the first two dimensions: batch-major -> time-major.
    src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs = [
        Variable(batch).transpose(0, 1)
        for batch in (src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs)
    ]

    if USE_CUDA:
        src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs = [
            batch.cuda()
            for batch in (src_seqs, trg_seqs, ind_seqs, gete_s, conv_seqs)
        ]
    return (src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
            src_plain, trg_plain, conv_seqs, conv_lengths, ent, ID, kb_arr)
开发者ID:ConvLab,项目名称:ConvLab,代码行数:41,代码来源:utils_babi_mem2seq.py

示例13: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch 11-field samples with entity annotations.

    Each sample unpacks to (src, trg, ind, gete, max_len, src_plain,
    trg_plain, entity, entity_cal, entity_nav, entity_wet).  ``data`` is
    sorted in place by the length of each sample's first field, longest
    first.  The four sequence fields are zero-padded to their own longest
    entry, wrapped in ``Variable``, transposed so the batch dimension comes
    second, and moved to the GPU when ``USE_CUDA`` is set.

    Returns:
        ``(src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
        src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet)``
    """
    def merge(sequences, max_len):
        # ``max_len`` is accepted for call compatibility only: the padded
        # width is always the longest sequence, whatever its value.
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for row, (seq, end) in enumerate(zip(sequences, lengths)):
            padded_seqs[row, :end] = seq[:end]
        return padded_seqs, lengths

    # Descending source-length order, intended for pack_padded_sequence.
    data.sort(key=lambda sample: len(sample[0]), reverse=True)
    (src_seqs, trg_seqs, ind_seqs, gete_s, max_len, src_plain, trg_plain,
     entity, entity_cal, entity_nav, entity_wet) = zip(*data)

    # Tuples of 1D tensors -> padded 2D tensors.
    src_seqs, src_lengths = merge(src_seqs, max_len)
    trg_seqs, trg_lengths = merge(trg_seqs, None)
    ind_seqs, _ = merge(ind_seqs, None)
    gete_s, _ = merge(gete_s, None)

    # (batch, time) -> (time, batch)
    src_seqs, trg_seqs, ind_seqs, gete_s = [
        Variable(batch).transpose(0, 1)
        for batch in (src_seqs, trg_seqs, ind_seqs, gete_s)
    ]
    if USE_CUDA:
        src_seqs, trg_seqs, ind_seqs, gete_s = [
            batch.cuda() for batch in (src_seqs, trg_seqs, ind_seqs, gete_s)
        ]
    return (src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
            src_plain, trg_plain, entity, entity_cal, entity_nav, entity_wet)
开发者ID:ConvLab,项目名称:ConvLab,代码行数:34,代码来源:utils_kvr.py

示例14: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Batch (src, trg, ind, gete, max_len, src_plain, trg_plain) samples.

    ``data`` is sorted in place by the length of each sample's first field,
    longest first.  The source batch is padded to the first sample's
    ``max_len`` value; the other three sequence fields are padded to their
    own longest entry.  The padding value is 1.  All four batches are
    wrapped in ``Variable``, transposed so the batch dimension comes
    second, and moved to the GPU when ``USE_CUDA`` is set.

    Returns:
        ``(src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
        src_plain, trg_plain)``
    """
    def merge(sequences, max_len):
        lengths = [len(seq) for seq in sequences]
        # A truthy max_len fixes the width; otherwise use the longest row.
        width = max_len[0] if max_len else max(lengths)
        padded_seqs = torch.ones(len(sequences), width).long()
        for row, (seq, end) in enumerate(zip(sequences, lengths)):
            padded_seqs[row, :end] = seq[:end]
        return padded_seqs, lengths

    # Descending source-length order, intended for pack_padded_sequence.
    data.sort(key=lambda sample: len(sample[0]), reverse=True)
    (src_seqs, trg_seqs, ind_seqs, gete_s,
     max_len, src_plain, trg_plain) = zip(*data)

    # Tuples of 1D tensors -> padded 2D tensors.
    src_seqs, src_lengths = merge(src_seqs, max_len)
    trg_seqs, trg_lengths = merge(trg_seqs, None)
    ind_seqs, _ = merge(ind_seqs, None)
    gete_s, _ = merge(gete_s, None)

    # (batch, time) -> (time, batch)
    src_seqs, trg_seqs, ind_seqs, gete_s = [
        Variable(batch).transpose(0, 1)
        for batch in (src_seqs, trg_seqs, ind_seqs, gete_s)
    ]
    if USE_CUDA:
        src_seqs, trg_seqs, ind_seqs, gete_s = [
            batch.cuda() for batch in (src_seqs, trg_seqs, ind_seqs, gete_s)
        ]
    return (src_seqs, src_lengths, trg_seqs, trg_lengths, ind_seqs, gete_s,
            src_plain, trg_plain)
开发者ID:ConvLab,项目名称:ConvLab,代码行数:34,代码来源:utils_babi.py

示例15: collate_fn

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import sort [as 别名]
def collate_fn(data):
    """Create mini-batch tensors from a list of 10-field samples.

    A custom collate function is needed because padding the variable-length
    ingredient captions is not something the default collate does.

    ``data`` is sorted in place by the length of field 5 (``ingr_cap``),
    longest first.

    Args:
        data: list of tuples ``(img, instrs, itr_ln, ingrs, igr_ln,
            ingr_cap, class_label, ret, one_hot_vec, food_id)``.
            ``img``, ``instrs``, ``ingrs``, ``ret`` and ``one_hot_vec`` are
            stacked, so each must have one shape across the batch;
            ``ingr_cap`` is a variable-length 1D tensor.

    Returns:
        A pair of lists:
        ``[images, instrs, itr_ln, ingrs, igr_ln, food_ids]`` and
        ``[images, instrs, itr_ln, ingrs, igr_ln, targets, lengths,
        class_label, ret, one_hot_vec]``, where ``targets`` is the
        zero-padded long tensor of ingredient captions and ``lengths``
        lists their true lengths.
    """
    data.sort(key=lambda sample: len(sample[5]), reverse=True)
    (img, instrs, itr_ln, ingrs, igr_ln,
     ingr_cap, class_label, ret, one_hot_vec, food_id) = zip(*data)

    # Tuples of per-sample tensors / ints -> batched tensors.
    images = torch.stack(img, 0)
    instrs = torch.stack(instrs, 0)
    itr_ln = torch.LongTensor(list(itr_ln))
    ingrs = torch.stack(ingrs, 0)
    igr_ln = torch.LongTensor(list(igr_ln))
    class_label = torch.LongTensor(list(class_label))
    ret = torch.stack(ret, 0)

    # Tuple of 1D captions -> one zero-padded 2D tensor.
    lengths = list(map(len, ingr_cap))
    targets = torch.zeros(len(ingr_cap), max(lengths)).long()
    for row, (caption, n_tokens) in enumerate(zip(ingr_cap, lengths)):
        targets[row, :n_tokens] = caption[:n_tokens]

    one_hot_vec = torch.stack(one_hot_vec, 0)

    shared = [images, instrs, itr_ln, ingrs, igr_ln]
    first = shared + [list(food_id)]
    second = shared + [targets, lengths, class_label, ret, one_hot_vec]
    return first, second
开发者ID:hwang1996,项目名称:ACME,代码行数:42,代码来源:data_loader.py


注:本文中的torch.utils.data.sort方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。