当前位置: 首页>>代码示例>>Python>>正文


Python data.dataset方法代码示例

本文整理汇总了Python中torch.utils.data.dataset方法的典型用法代码示例。如果您正苦于以下问题:Python data.dataset方法的具体用法?Python data.dataset怎么用?Python data.dataset使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data的用法示例。


在下文中一共展示了data.dataset方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_dataset

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import dataset [as 别名]
def get_dataset(self, dataset_type, transforms=None):
    """Collect the stereo file records of one split and wrap them in a dataset.

    Only entries of ``self.files_list`` whose ``'phase'`` equals ``'Train'``
    (for ``DatasetType.TRAIN``) or ``'Test'`` (for ``DatasetType.TEST``) are
    used; any other ``dataset_type`` collects no files.  For each matching
    entry the records returned by ``self.get_files_of_taxonomy(name, samples)``
    are appended, and a log line reports whether the number of records added
    equals the entry's ``'pair_num'``.

    Args:
        dataset_type: ``DatasetType`` member selecting the split.
        transforms: passed through unchanged to ``StereoDeblurDataset``.

    Returns:
        ``StereoDeblurDataset(files, transforms)``.
    """
    # Both splits share the same collection logic; only the phase label that
    # is matched in the JSON entries differs, so resolve it once.
    if dataset_type == DatasetType.TRAIN:
        wanted_phase = 'Train'
    elif dataset_type == DatasetType.TEST:
        wanted_phase = 'Test'
    else:
        wanted_phase = None

    files = []
    if wanted_phase is not None:
        # Load data for each sequence
        for file in self.files_list:
            if file['phase'] != wanted_phase:
                continue
            name = file['name']
            pair_num = file['pair_num']
            samples = file['sample']
            files_num_old = len(files)
            files.extend(self.get_files_of_taxonomy(name, samples))
            # Sanity check: did this taxonomy contribute exactly pair_num records?
            print('[INFO] %s Collecting files of Taxonomy [Name = %s, Pair Numbur = %s, Loaded = %r]' % (
                dt.now(), name, pair_num, pair_num == (len(files) - files_num_old)))

    print('[INFO] %s Complete collecting files of the dataset for %s. Total Pair Numbur: %d.\n' % (dt.now(), dataset_type.name, len(files)))
    return StereoDeblurDataset(files, transforms)
开发者ID:sczhou,项目名称:DAVANet,代码行数:25,代码来源:data_loaders.py

示例2: val

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import dataset [as 别名]
def val(model, dataset):
    """Evaluate ``model`` on ``dataset`` and return the validation scores.

    The dataset is put into its non-training mode (``dataset.train(False)``)
    and the model into ``eval()`` for the pass; both are switched back to
    training mode afterwards, including when an exception is raised.

    For every sample the top-5 predicted class indices are paired with the
    sample's true class indices, taken as the positions among the five largest
    ``label`` entries whose value is > 0.  The list of
    ``(predicted_indices, true_indices)`` pairs is handed to ``get_score``.

    Returns:
        ``(scores, prec_, recall_, _ss)`` exactly as returned by ``get_score``.
    """
    dataset.train(False)
    model.eval()
    predict_label_and_marked_label_list = []
    try:
        dataloader = data.DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )

        # Iterate the loader directly (the batch index was unused) so tqdm can
        # also display the total number of batches.
        for (title, content), label in tqdm.tqdm(dataloader):
            # In this model variant title and content are each a pair of tensors.
            title = (Variable(title[0].cuda()), Variable(title[1].cuda()))
            content = (Variable(content[0].cuda()), Variable(content[1].cuda()))
            label = Variable(label.cuda())
            score = model(title, content)

            # Indices of the 5 highest-scoring classes for each sample.
            predict = score.data.topk(5, dim=1)[1].cpu().tolist()
            # Release this batch's output before the next batch. Deleting it
            # here (rather than once after the loop) also cannot raise
            # NameError when the loader yields no batches.
            del score

            # topk returns (values, indices) with exactly 5 columns.
            true_label, true_index = label.data.float().topk(5, dim=1)
            for pred_row, index_row, value_row in zip(predict, true_index, true_label):
                # Keep only the indices whose label value is positive.
                predict_label_and_marked_label_list.append(
                    (pred_row, index_row[value_row > 0].tolist()))
    finally:
        dataset.train(True)
        model.train()

    scores, prec_, recall_, _ss = get_score(predict_label_and_marked_label_list)
    return (scores, prec_, recall_, _ss)
开发者ID:chenyuntc,项目名称:PyTorchText,代码行数:43,代码来源:main-all.py

示例3: predict

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import dataset [as 别名]
def predict(self, data: data.dataset, **kwargs):
    """Run ``self._model`` over every batch of ``data`` and stack the outputs.

    Batches are drawn in order (``shuffle=False``) with the batch size taken
    from ``data.batch_size``.  Each batch yields ``(feature, label)``; only the
    feature, cast to ``float32``, is fed to the model.

    Args:
        data: dataset exposing a ``batch_size`` attribute.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        numpy array of model outputs.  With a single batch this is that
        batch's output as-is; otherwise the per-batch outputs are combined
        with ``np.vstack``.

    Raises:
        ValueError: if ``data`` produces no batches.
    """
    predict_data = DataLoader(data, batch_size=data.batch_size, shuffle=False)
    outputs = []
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        for feature, _label in predict_data:
            # as_tensor avoids the copy (and warning) torch.tensor() incurs
            # when feature is already a tensor.
            feature = torch.as_tensor(feature, dtype=torch.float32)
            y = self._model(feature)
            outputs.append(y.detach().cpu().numpy())

    if not outputs:
        raise ValueError('predict() received a dataset that produced no batches')
    if len(outputs) == 1:
        return outputs[0]
    # Stack once instead of re-stacking the growing result every batch (O(n^2)).
    return np.vstack(outputs)
开发者ID:FederatedAI,项目名称:FATE,代码行数:14,代码来源:nn_model.py

示例4: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import dataset [as 别名]
def __init__(self):
    """Read the frame path templates from the config and load the dataset index.

    Sets ``img_blur_path_template`` and ``img_clear_path_template`` from
    ``cfg.DIR`` and parses the UTF-8 JSON file at
    ``cfg.DIR.DATASET_JSON_FILE_PATH`` into ``self.files_list``.
    """
    dirs = cfg.DIR
    self.img_blur_path_template = dirs.IMAGE_BLUR_PATH
    self.img_clear_path_template = dirs.IMAGE_CLEAR_PATH

    # Load all files of the dataset
    with io.open(dirs.DATASET_JSON_FILE_PATH, encoding='utf-8') as json_fp:
        self.files_list = json.load(json_fp)
开发者ID:sczhou,项目名称:STFAN,代码行数:9,代码来源:data_loaders.py

示例5: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import dataset [as 别名]
def __init__(self):
    """Read the stereo path templates from the config and load the dataset index.

    Sets the left/right image and left/right disparity path templates from
    ``cfg.DIR`` and parses the UTF-8 JSON file at
    ``cfg.DIR.DATASET_JSON_FILE_PATH`` into ``self.files_list``.
    """
    dirs = cfg.DIR
    self.img_left_path_template = dirs.IMAGE_LEFT_PATH
    self.img_right_path_template = dirs.IMAGE_RIGHT_PATH
    self.disp_left_path_template = dirs.DISPARITY_LEFT_PATH
    self.disp_right_path_template = dirs.DISPARITY_RIGHT_PATH

    # Load all files of the dataset
    with io.open(dirs.DATASET_JSON_FILE_PATH, encoding='utf-8') as json_fp:
        self.files_list = json.load(json_fp)
开发者ID:sczhou,项目名称:DAVANet,代码行数:10,代码来源:data_loaders.py

示例6: val

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import dataset [as 别名]
def val(model, dataset):
    """Compute the model's scores on the validation set.

    The dataset is put into its non-training mode (``dataset.train(False)``)
    and the model into ``eval()`` for the pass; both are switched back to
    training mode afterwards, including when an exception is raised.

    For every sample the top-5 predicted class indices are paired with the
    sample's true class indices, taken as the positions among the five largest
    ``label`` entries whose value is > 0.  The list of
    ``(predicted_indices, true_indices)`` pairs is handed to ``get_score``.

    Returns:
        ``(scores, prec_, recall_, _ss)`` exactly as returned by ``get_score``.
    """
    dataset.train(False)
    model.eval()
    predict_label_and_marked_label_list = []
    try:
        dataloader = data.DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )

        # Iterate the loader directly (the batch index was unused) so tqdm can
        # also display the total number of batches.
        for (title, content), label in tqdm.tqdm(dataloader):
            # NOTE(review): volatile=True is the pre-0.4 PyTorch way to disable
            # autograd; newer versions ignore it with a warning and expect
            # torch.no_grad() instead -- confirm the targeted torch version.
            title = Variable(title.cuda(), volatile=True)
            content = Variable(content.cuda(), volatile=True)
            label = Variable(label.cuda(), volatile=True)
            score = model(title, content)

            # Indices of the 5 highest-scoring classes for each sample.
            predict = score.data.topk(5, dim=1)[1].cpu().tolist()
            # Release this batch's output before the next batch. Deleting it
            # here (rather than once after the loop) also cannot raise
            # NameError when the loader yields no batches.
            del score

            # topk returns (values, indices) with exactly 5 columns.
            true_label, true_index = label.data.float().topk(5, dim=1)
            for pred_row, index_row, value_row in zip(predict, true_index, true_label):
                # Keep only the indices whose label value is positive.
                predict_label_and_marked_label_list.append(
                    (pred_row, index_row[value_row > 0].tolist()))
    finally:
        dataset.train(True)
        model.train()

    scores, prec_, recall_, _ss = get_score(predict_label_and_marked_label_list)
    return (scores, prec_, recall_, _ss)
开发者ID:chenyuntc,项目名称:PyTorchText,代码行数:48,代码来源:main.py

示例7: get_dataset

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import dataset [as 别名]
def get_dataset(self, dataset_type, transforms=None):
    """Cut every clip of one split into ``cfg.DATA.SEQ_LENGTH``-frame windows.

    Only entries of ``self.files_list`` whose ``'phase'`` equals ``'train'``
    (for ``DatasetType.TRAIN``) or ``'test'`` (for ``DatasetType.TEST``) are
    used; any other ``dataset_type`` collects nothing.  Each clip's
    ``'sample'`` list is split into consecutive non-overlapping windows of
    ``seq_len`` frames.  When the clip length is not a multiple of
    ``seq_len``, one extra window made of the last ``seq_len`` frames is added
    so the trailing frames are not lost (it overlaps the previous window).
    Clips shorter than ``seq_len`` contribute no window.

    Args:
        dataset_type: ``DatasetType`` member selecting the split.
        transforms: passed through unchanged to ``VideoDeblurDataset``.

    Returns:
        ``VideoDeblurDataset(sequences, transforms)``.
    """
    # Both splits share the same windowing logic; only the phase label that
    # is matched in the JSON entries differs, so resolve it once.
    if dataset_type == DatasetType.TRAIN:
        wanted_phase = 'train'
    elif dataset_type == DatasetType.TEST:
        wanted_phase = 'test'
    else:
        wanted_phase = None

    sequences = []
    if wanted_phase is not None:
        seq_len = cfg.DATA.SEQ_LENGTH
        # Load data for each sequence
        for file in self.files_list:
            if file['phase'] != wanted_phase:
                continue
            name = file['name']
            phase = file['phase']
            samples = file['sample']
            sam_len = len(samples)
            seq_num = sam_len // seq_len
            for n in range(seq_num):
                sequence = self.get_files_of_taxonomy(phase, name, samples[seq_len * n: seq_len * (n + 1)])
                sequences.extend(sequence)

            # The remainder must be taken on the clip length: the former
            # check `seq_len % seq_len` was always 0, so tail frames were
            # silently dropped.  Require one full window so the extra
            # window is also exactly seq_len frames long.
            if seq_num > 0 and sam_len % seq_len != 0:
                sequence = self.get_files_of_taxonomy(phase, name, samples[-seq_len:])
                sequences.extend(sequence)
                seq_num += 1

            print('[INFO] %s Collecting files of Taxonomy [Name = %s]' % (dt.now(), name + ': ' + str(seq_num)))

    print('[INFO] %s Complete collecting files of the dataset for %s. Seq Number: %d.\n' % (dt.now(), dataset_type.name, len(sequences)))
    return VideoDeblurDataset(sequences, transforms)
开发者ID:sczhou,项目名称:STFAN,代码行数:45,代码来源:data_loaders.py


注:本文中的torch.utils.data.dataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。