This article collects typical usage examples of torch.utils.data.dataset in Python. If you have been wondering what data.dataset is for, how to use it, or what real code that uses it looks like, the curated examples below should help. You can also read further about the containing package, torch.utils.data.
Seven code examples involving data.dataset are shown below, sorted by popularity by default.
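As a quick orientation before the examples: the usual pattern with this package is to subclass torch.utils.data.Dataset, implement __len__ and __getitem__, and hand the instance to a DataLoader. A minimal, self-contained sketch with toy data (not taken from any of the projects below):

import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Toy dataset mapping index i to the pair (i, i ** 2)."""
    def __init__(self, n=100):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor(float(idx))
        return x, x ** 2

loader = DataLoader(SquaresDataset(), batch_size=8, shuffle=True)
for features, targets in loader:
    pass  # each iteration yields batched 1-D tensors (the last batch may be smaller)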
Example 1: get_dataset
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import dataset [as alias]
def get_dataset(self, dataset_type, transforms=None):
    files = []
    # Load data for each sequence
    for file in self.files_list:
        if dataset_type == DatasetType.TRAIN and file['phase'] == 'Train':
            name = file['name']
            pair_num = file['pair_num']
            samples = file['sample']
            files_num_old = len(files)
            files.extend(self.get_files_of_taxonomy(name, samples))
            print('[INFO] %s Collecting files of Taxonomy [Name = %s, Pair Number = %s, Loaded = %r]' % (
                dt.now(), name, pair_num, pair_num == (len(files) - files_num_old)))
        elif dataset_type == DatasetType.TEST and file['phase'] == 'Test':
            name = file['name']
            pair_num = file['pair_num']
            samples = file['sample']
            files_num_old = len(files)
            files.extend(self.get_files_of_taxonomy(name, samples))
            print('[INFO] %s Collecting files of Taxonomy [Name = %s, Pair Number = %s, Loaded = %r]' % (
                dt.now(), name, pair_num, pair_num == (len(files) - files_num_old)))
    print('[INFO] %s Complete collecting files of the dataset for %s. Total Pair Number: %d.\n' % (
        dt.now(), dataset_type.name, len(files)))
    return StereoDeblurDataset(files, transforms)
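Assuming StereoDeblurDataset subclasses torch.utils.data.Dataset (which the import comments above suggest), the returned object would be consumed in the usual way. A sketch, where the instance name and loader settings are hypothetical:

from torch.utils.data import DataLoader

train_dataset = dataset_loader.get_dataset(DatasetType.TRAIN)  # dataset_loader: hypothetical instance of this class
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)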
Example 2: val
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import dataset [as alias]
def val(model, dataset):
    dataset.train(False)
    model.eval()
    dataloader = data.DataLoader(dataset,
                                 batch_size=opt.batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers,
                                 pin_memory=True)
    predict_label_and_marked_label_list = []
    for ii, ((title, content), label) in tqdm.tqdm(enumerate(dataloader)):
        # title and content each arrive as a pair of tensors
        title = (Variable(title[0].cuda()), Variable(title[1].cuda()))
        content = (Variable(content[0].cuda()), Variable(content[1].cuda()))
        label = Variable(label.cuda())
        score = model(title, content)
        # TODO: optimize this block
        # 1. avoid repeated list appends
        # 2. vectorize the inner for loop
        # 3. use topk instead of sort
        predict = score.data.topk(5, dim=1)[1].cpu().tolist()
        true_target = label.data.float().topk(5, dim=1)
        true_index = true_target[1][:, :5]
        true_label = true_target[0][:, :5]
        tmp = []
        for jj in range(label.size(0)):
            true_index_ = true_index[jj]
            true_label_ = true_label[jj]
            true = true_index_[true_label_ > 0]
            tmp.append((predict[jj], true.tolist()))
        predict_label_and_marked_label_list.extend(tmp)
        del score
    dataset.train(True)
    model.train()
    scores, prec_, recall_, _ss = get_score(predict_label_and_marked_label_list)
    return (scores, prec_, recall_, _ss)
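The subtle step above is recovering the true labels: label is a multi-hot float matrix, so topk(5, dim=1) returns the five largest entries per row together with their column indices, and masking the indices by value > 0 keeps only the genuinely positive labels. A standalone illustration with toy data:

import torch

label = torch.tensor([[0., 1., 0., 1., 0., 0.]])  # two positive labels, at columns 1 and 3
values, indices = label.topk(5, dim=1)            # five largest values and their indices
true = indices[0][values[0] > 0]                  # keep only indices whose value is positive
print(sorted(true.tolist()))                      # [1, 3]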
Example 3: predict
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import dataset [as alias]
def predict(self, data: data.dataset, **kwargs):
    # Note: data.dataset is a module; the Dataset class itself is data.Dataset.
    # The dataset passed in is assumed to carry a batch_size attribute.
    predict_data = DataLoader(data, batch_size=data.batch_size, shuffle=False)
    for batch_id, (feature, label) in enumerate(predict_data):
        feature = torch.tensor(feature, dtype=torch.float32)
        # label is not needed at inference time
        y = self._model(feature)
        if batch_id == 0:
            result = y.detach().numpy()
        else:
            result = np.vstack((result, y.detach().numpy()))
    return result
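A small design note: np.vstack inside the loop re-copies the accumulated array on every batch, which is quadratic in the number of batches. An equivalent pattern, sketched with the method's own names, collects per-batch outputs in a list and stacks once:

outputs = []
for feature, label in predict_data:
    feature = torch.tensor(feature, dtype=torch.float32)
    outputs.append(self._model(feature).detach().numpy())
result = np.vstack(outputs)  # single copy at the end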
Example 4: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import dataset [as alias]
def __init__(self):
self.img_blur_path_template = cfg.DIR.IMAGE_BLUR_PATH
self.img_clear_path_template = cfg.DIR.IMAGE_CLEAR_PATH
# Load all files of the dataset
with io.open(cfg.DIR.DATASET_JSON_FILE_PATH, encoding='utf-8') as file:
self.files_list = json.loads(file.read())
Example 5: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import dataset [as alias]
def __init__(self):
self.img_left_path_template = cfg.DIR.IMAGE_LEFT_PATH
self.img_right_path_template = cfg.DIR.IMAGE_RIGHT_PATH
self.disp_left_path_template = cfg.DIR.DISPARITY_LEFT_PATH
self.disp_right_path_template = cfg.DIR.DISPARITY_RIGHT_PATH
# Load all files of the dataset
with io.open(cfg.DIR.DATASET_JSON_FILE_PATH, encoding='utf-8') as file:
self.files_list = json.loads(file.read())
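In both __init__ examples, json.loads(file.read()) works, but json.load(file) is the slightly more idiomatic equivalent, parsing straight from the file object:

import io
import json

with io.open(cfg.DIR.DATASET_JSON_FILE_PATH, encoding='utf-8') as file:
    files_list = json.load(file)  # equivalent to json.loads(file.read())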
Example 6: val
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import dataset [as alias]
def val(model, dataset):
    '''
    Compute the model's score on the validation set.
    '''
    dataset.train(False)
    model.eval()
    dataloader = data.DataLoader(dataset,
                                 batch_size=opt.batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers,
                                 pin_memory=True)
    predict_label_and_marked_label_list = []
    for ii, ((title, content), label) in tqdm.tqdm(enumerate(dataloader)):
        # volatile=True is the pre-0.4 way of disabling autograd;
        # see the torch.no_grad() sketch after this example
        title, content, label = Variable(title.cuda(), volatile=True), \
                                Variable(content.cuda(), volatile=True), \
                                Variable(label.cuda(), volatile=True)
        score = model(title, content)
        # TODO: optimize this block
        # 1. avoid repeated list appends
        # 2. vectorize the inner for loop
        # 3. use topk instead of sort
        predict = score.data.topk(5, dim=1)[1].cpu().tolist()
        true_target = label.data.float().topk(5, dim=1)
        true_index = true_target[1][:, :5]
        true_label = true_target[0][:, :5]
        tmp = []
        for jj in range(label.size(0)):
            true_index_ = true_index[jj]
            true_label_ = true_label[jj]
            true = true_index_[true_label_ > 0]
            tmp.append((predict[jj], true.tolist()))
        predict_label_and_marked_label_list.extend(tmp)
        del score
    dataset.train(True)
    model.train()
    scores, prec_, recall_, _ss = get_score(predict_label_and_marked_label_list)
    return (scores, prec_, recall_, _ss)
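The volatile=True flag was deprecated in PyTorch 0.4 and no longer has any effect; the modern equivalent is to wrap the evaluation loop in torch.no_grad(), which disables gradient tracking for everything inside the block. A sketch of the updated loop, assuming the same model and dataloader as above:

import torch

with torch.no_grad():  # replaces Variable(..., volatile=True)
    for (title, content), label in dataloader:
        score = model(title.cuda(), content.cuda())
        # ... same top-k scoring as above ...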
Example 7: get_dataset
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import dataset [as alias]
def get_dataset(self, dataset_type, transforms=None):
    sequences = []
    # Load data for each sequence
    for file in self.files_list:
        if dataset_type == DatasetType.TRAIN and file['phase'] == 'train':
            name = file['name']
            phase = file['phase']
            samples = file['sample']
            sam_len = len(samples)
            seq_len = cfg.DATA.SEQ_LENGTH
            seq_num = int(sam_len / seq_len)
            for n in range(seq_num):
                sequence = self.get_files_of_taxonomy(phase, name, samples[seq_len * n: seq_len * (n + 1)])
                sequences.extend(sequence)
            # If the samples do not divide evenly into sequences, take one
            # extra sequence from the tail (it overlaps the previous one)
            if sam_len % seq_len != 0:
                sequence = self.get_files_of_taxonomy(phase, name, samples[-seq_len:])
                sequences.extend(sequence)
                seq_num += 1
            print('[INFO] %s Collecting files of Taxonomy [Name = %s]' % (dt.now(), name + ': ' + str(seq_num)))
        elif dataset_type == DatasetType.TEST and file['phase'] == 'test':
            name = file['name']
            phase = file['phase']
            samples = file['sample']
            sam_len = len(samples)
            seq_len = cfg.DATA.SEQ_LENGTH
            seq_num = int(sam_len / seq_len)
            for n in range(seq_num):
                sequence = self.get_files_of_taxonomy(phase, name, samples[seq_len * n: seq_len * (n + 1)])
                sequences.extend(sequence)
            # Same tail handling as in the training branch
            if sam_len % seq_len != 0:
                sequence = self.get_files_of_taxonomy(phase, name, samples[-seq_len:])
                sequences.extend(sequence)
                seq_num += 1
            print('[INFO] %s Collecting files of Taxonomy [Name = %s]' % (dt.now(), name + ': ' + str(seq_num)))
    print('[INFO] %s Complete collecting files of the dataset for %s. Seq Number: %d.\n' % (dt.now(), dataset_type.name, len(sequences)))
    return VideoDeblurDataset(sequences, transforms)
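The chunking arithmetic, including the tail handling, is easiest to see in isolation: samples is cut into consecutive windows of seq_len, and when a remainder is left over, one extra window is taken from the end of the list, overlapping the previous one. A toy illustration:

samples = list(range(10))  # sam_len = 10
seq_len = 4
chunks = [samples[seq_len * n: seq_len * (n + 1)] for n in range(len(samples) // seq_len)]
if len(samples) % seq_len != 0:
    chunks.append(samples[-seq_len:])  # overlapping tail window
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [6, 7, 8, 9]]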