本文整理匯總了Python中torch.utils.data.dataset方法的典型用法代碼示例。如果您正苦於以下問題:Python data.dataset方法的具體用法?Python data.dataset怎麽用?Python data.dataset使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類torch.utils.data
的用法示例。
在下文中一共展示了data.dataset方法的7個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: get_dataset
# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import dataset [as 別名]
def get_dataset(self, dataset_type, transforms=None):
    """Collect file entries for the requested split and wrap them in a dataset.

    Args:
        dataset_type: DatasetType.TRAIN or DatasetType.TEST; selects the
            entries of self.files_list whose 'phase' field is 'Train' /
            'Test' respectively.
        transforms: optional transforms forwarded to StereoDeblurDataset.

    Returns:
        A StereoDeblurDataset over all collected files.
    """
    # Map the requested split onto the 'phase' tag used in the JSON file list.
    # The original duplicated the whole loop body in two identical branches;
    # this keeps one copy. (Also fixes the 'Numbur' typo in the log messages.)
    phase_of = {DatasetType.TRAIN: 'Train', DatasetType.TEST: 'Test'}
    wanted_phase = phase_of.get(dataset_type)

    files = []
    # Load data for each sequence
    for file in self.files_list:
        if file['phase'] != wanted_phase:
            continue
        name = file['name']
        pair_num = file['pair_num']
        samples = file['sample']
        files_num_old = len(files)
        files.extend(self.get_files_of_taxonomy(name, samples))
        # 'Loaded' reports whether the number of collected pairs matches pair_num.
        print('[INFO] %s Collecting files of Taxonomy [Name = %s, Pair Number = %s, Loaded = %r]' % (
            dt.now(), name, pair_num, pair_num == (len(files) - files_num_old)))
    print('[INFO] %s Complete collecting files of the dataset for %s. Total Pair Number: %d.\n' % (dt.now(), dataset_type.name, len(files)))
    return StereoDeblurDataset(files, transforms)
示例2: val
# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import dataset [as 別名]
def val(model,dataset):
    """Evaluate *model* on the validation *dataset*.

    Switches the dataset and model into eval mode, gathers for every sample
    the top-5 predicted label indices together with the positively-marked
    true label indices, restores train mode, and returns the tuple produced
    by get_score().

    Args:
        model: network called as model(title, content); here title and
            content are each a 2-tuple of GPU tensors (see the loop body).
        dataset: must support .train(bool) to toggle validation mode.

    Returns:
        (scores, prec_, recall_, _ss) as computed by get_score().
    """
    dataset.train(False)
    model.eval()
    dataloader = data.DataLoader(dataset,
                    batch_size = opt.batch_size,
                    shuffle = False,
                    num_workers = opt.num_workers,
                    pin_memory = True
                    )
    predict_label_and_marked_label_list=[]
    for ii,((title,content),label) in tqdm.tqdm(enumerate(dataloader)):
        # Move every tensor to the GPU; title/content each stay a 2-tuple.
        title,content,label = (Variable(title[0].cuda()),Variable(title[1].cuda())),(Variable(content[0].cuda()),Variable(content[1].cuda())),Variable(label.cuda())
        score = model(title,content)
        # !TODO: optimize this section
        # 1. append
        # 2. for loop
        # 3. use topk instead of sort
        predict = score.data.topk(5,dim=1)[1].cpu().tolist()
        true_target = label.data.float().topk(5,dim=1)  # (values, indices) of the 5 largest label entries
        true_index=true_target[1][:,:5]
        true_label=true_target[0][:,:5]
        tmp= []
        for jj in range(label.size(0)):
            true_index_=true_index[jj]
            true_label_=true_label[jj]
            # Keep only the indices whose label value is positive (actually marked).
            true=true_index_[true_label_>0]
            tmp.append((predict[jj],true.tolist()))
        predict_label_and_marked_label_list.extend(tmp)
        # Free the score tensor before the next batch.
        del score
    dataset.train(True)
    model.train()
    scores,prec_,recall_,_ss=get_score(predict_label_and_marked_label_list)
    return (scores,prec_,recall_,_ss)
示例3: predict
# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import dataset [as 別名]
def predict(self, data: data.dataset, **kwargs):
    """Run self._model over *data* in batches and return stacked predictions.

    Args:
        data: a torch dataset yielding (feature, label) pairs; must expose a
            `batch_size` attribute, which sizes the DataLoader batches.
            Labels are ignored.
        **kwargs: unused; kept for interface compatibility.

    Returns:
        numpy.ndarray with one row of model output per input sample, or an
        empty 1-D array when the dataset yields no batches (the original
        raised UnboundLocalError in that case).
    """
    predict_data = DataLoader(data, batch_size=data.batch_size, shuffle=False)
    outputs = []
    for feature, _label in predict_data:
        # as_tensor avoids the copy warning torch.tensor() emits on tensors.
        feature = torch.as_tensor(feature, dtype=torch.float32)
        y = self._model(feature)
        outputs.append(y.detach().numpy())
    # Stack once at the end instead of np.vstack inside the loop (O(n^2)).
    return np.vstack(outputs) if outputs else np.empty((0,))
示例4: __init__
# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import dataset [as 別名]
def __init__(self):
    """Cache blur/clear image path templates and load the JSON file list."""
    dir_cfg = cfg.DIR
    self.img_blur_path_template = dir_cfg.IMAGE_BLUR_PATH
    self.img_clear_path_template = dir_cfg.IMAGE_CLEAR_PATH
    # Parse the JSON index describing every file of the dataset.
    with io.open(dir_cfg.DATASET_JSON_FILE_PATH, encoding='utf-8') as json_file:
        self.files_list = json.load(json_file)
示例5: __init__
# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import dataset [as 別名]
def __init__(self):
    """Cache stereo image/disparity path templates and load the JSON file list."""
    dir_cfg = cfg.DIR
    self.img_left_path_template = dir_cfg.IMAGE_LEFT_PATH
    self.img_right_path_template = dir_cfg.IMAGE_RIGHT_PATH
    self.disp_left_path_template = dir_cfg.DISPARITY_LEFT_PATH
    self.disp_right_path_template = dir_cfg.DISPARITY_RIGHT_PATH
    # Parse the JSON index describing every file of the dataset.
    with io.open(dir_cfg.DATASET_JSON_FILE_PATH, encoding='utf-8') as json_file:
        self.files_list = json.load(json_file)
示例6: val
# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import dataset [as 別名]
def val(model, dataset):
    """Score the model on the validation set.

    Switches model/dataset into eval mode, collects a (top-5 predictions,
    positively-marked true labels) pair for every sample, restores train
    mode, and returns the tuple produced by get_score().
    """
    dataset.train(False)
    model.eval()
    loader = data.DataLoader(dataset,
                             batch_size=opt.batch_size,
                             shuffle=False,
                             num_workers=opt.num_workers,
                             pin_memory=True)
    pred_and_truth = []
    for step, ((title, content), label) in tqdm.tqdm(enumerate(loader)):
        # volatile=True: inference only, no autograd bookkeeping (legacy API).
        title = Variable(title.cuda(), volatile=True)
        content = Variable(content.cuda(), volatile=True)
        label = Variable(label.cuda(), volatile=True)
        score = model(title, content)
        top5_pred = score.data.topk(5, dim=1)[1].cpu().tolist()
        # topk returns (values, indices) of the 5 largest label entries.
        top5_vals, top5_idx = label.data.float().topk(5, dim=1)
        for row in range(label.size(0)):
            # Keep only the indices whose label value is positive.
            truth = top5_idx[row][top5_vals[row] > 0]
            pred_and_truth.append((top5_pred[row], truth.tolist()))
        del score
    dataset.train(True)
    model.train()
    scores, prec_, recall_, _ss = get_score(pred_and_truth)
    return (scores, prec_, recall_, _ss)
示例7: get_dataset
# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import dataset [as 別名]
def get_dataset(self, dataset_type, transforms=None):
    """Collect fixed-length frame sequences for the requested split.

    The samples of each taxonomy are chunked into consecutive sequences of
    cfg.DATA.SEQ_LENGTH frames; when the sample count is not a multiple of
    the sequence length, one extra (overlapping) sequence made of the last
    SEQ_LENGTH samples covers the remainder.

    Args:
        dataset_type: DatasetType.TRAIN or DatasetType.TEST; selects file
            entries whose 'phase' field is 'train' / 'test' respectively.
        transforms: optional transforms forwarded to VideoDeblurDataset.

    Returns:
        A VideoDeblurDataset over all collected sequences.
    """
    # Map the requested split onto the 'phase' tag used in the JSON file list.
    # The original duplicated the whole loop body in two identical branches.
    phase_of = {DatasetType.TRAIN: 'train', DatasetType.TEST: 'test'}
    wanted_phase = phase_of.get(dataset_type)

    sequences = []
    # Load data for each sequence
    for file in self.files_list:
        if file['phase'] != wanted_phase:
            continue
        name = file['name']
        phase = file['phase']
        samples = file['sample']
        sam_len = len(samples)
        seq_len = cfg.DATA.SEQ_LENGTH
        seq_num = int(sam_len / seq_len)
        # Full-length, non-overlapping chunks.
        for n in range(seq_num):
            sequence = self.get_files_of_taxonomy(phase, name, samples[seq_len * n: seq_len * (n + 1)])
            sequences.extend(sequence)
        # BUGFIX: the original tested `seq_len % seq_len`, which is always 0,
        # so this branch never ran and trailing samples were silently dropped.
        if sam_len % seq_len != 0:
            sequence = self.get_files_of_taxonomy(phase, name, samples[-seq_len:])
            sequences.extend(sequence)
            seq_num += 1
        print('[INFO] %s Collecting files of Taxonomy [Name = %s]' % (dt.now(), name + ': ' + str(seq_num)))
    print('[INFO] %s Complete collecting files of the dataset for %s. Seq Number: %d.\n' % (dt.now(), dataset_type.name, len(sequences)))
    return VideoDeblurDataset(sequences, transforms)