本文整理汇总了Python中torch.utils.data.dataset.Dataset类的典型用法代码示例。如果您正苦于以下问题：Python dataset.Dataset类的具体用法？Python dataset.Dataset怎么用？Python dataset.Dataset使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch.utils.data.dataset的用法示例。
在下文中一共展示了dataset.Dataset类的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: one_batch_dataset
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def one_batch_dataset(dataset, batch_size, generator=None):
    """Return a Dataset that always serves one random batch drawn from ``dataset``.

    Useful for overfitting sanity checks: the returned dataset contains the
    same ``batch_size`` samples on every epoch.

    Args:
        dataset: any object supporting ``len()`` and integer ``__getitem__``.
        batch_size (int): number of samples to keep. If it exceeds
            ``len(dataset)``, every sample is kept (in shuffled order).
        generator (torch.Generator, optional): RNG used for the permutation.
            Pass a seeded generator for a reproducible batch; ``None`` (the
            default) uses torch's global RNG, as before.

    Returns:
        Dataset: an in-memory dataset over the selected samples.
    """
    print("==> Grabbing a single batch")
    perm = torch.randperm(len(dataset), generator=generator)
    # .tolist() yields plain Python ints, so any indexable container works.
    one_batch = [dataset[idx] for idx in perm[:batch_size].tolist()]

    class _OneBatchWrapper(Dataset):
        """Minimal Dataset over a pre-materialized list of samples."""

        def __init__(self, batch):
            self.batch = batch

        def __getitem__(self, index):
            return self.batch[index]

        def __len__(self):
            return len(self.batch)

    return _OneBatchWrapper(one_batch)
示例2: __init__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(
    self,
    video_list: List[str],
    frame_selector: Optional[FrameSelector] = None,
    transform: Optional[FrameTransform] = None,
):
    """
    Build a dataset over a collection of video files.

    Args:
        video_list (List[str]): paths of the video files to read
        frame_selector (Callable: KeyFrameList -> KeyFrameList): picks the
            keyframes to process; keyframes are identified by packet
            timestamps in timebase counts. ``None`` selects every keyframe
            (default: None)
        transform (Callable: torch.Tensor -> torch.Tensor): applied to a
            batch of RGB images (tensor of size [B, H, W, 3]) and returns a
            tensor of the same size. ``None`` means no transform is applied
            (default: None)
    """
    # The constructor only records its configuration; nothing is opened here.
    self.video_list, self.frame_selector, self.transform = (
        video_list,
        frame_selector,
        transform,
    )
示例3: __init__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    """Load the processed training or test split stored under ``root``.

    Args:
        root (str): dataset root directory; ``~`` is expanded.
        train (bool): load the training split if True, the test split otherwise.
        transform: stored as ``self.transform`` (presumably applied to samples).
        target_transform: stored as ``self.target_transform`` (presumably
            applied to labels).
        download (bool): call ``self.download()`` before checking for the data.

    Raises:
        RuntimeError: if ``self._check_exists()`` reports the data is missing.
    """
    self.root = os.path.expanduser(root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # True -> training split, False -> test split

    if download:
        self.download()
    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')

    # Pick the file for the requested split, then load it once.
    split_file = self.training_file if self.train else self.test_file
    # NOTE(review): torch.load unpickles the file; only use it on trusted data.
    data, labels = torch.load(os.path.join(self.root, self.processed_folder, split_file))
    if self.train:
        self.train_data, self.train_labels = data, labels
    else:
        self.test_data, self.test_labels = data, labels
示例4: __init__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, dataset_name, data_path, train=True):
    """Initialize empty file-name lists and fill them for ``dataset_name``.

    Args:
        dataset_name (str): dataset layout to load; only ``'matterport'``
            is currently supported.
        data_path (str): root directory, stored as ``self.data_root``.
        train (bool): forwarded to the dataset-specific loader.

    Raises:
        ValueError: if ``dataset_name`` is not supported. ``ValueError`` is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    # Single source of truth: each supported name maps to the method that
    # populates the file-name lists, so validation and dispatch cannot drift.
    loaders = {'matterport': self._load_data_name_matterport}

    self.dataset_name = dataset_name
    if self.dataset_name not in loaders:
        raise ValueError(f'Dataset name not found: {self.dataset_name}')
    self.data_root = data_path
    self.len = 0
    self.train = train

    # Per-sample file-name lists, filled in by the loader below.
    self.scene_name = []
    self.color_name = []
    self.depth_name = []
    self.normal_name = []
    self.render_name = []
    self.boundary_name = []
    self.depth_boundary_name = []

    loaders[self.dataset_name](train=self.train)
示例5: __init__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, owner_dataset, indices=None, filter_name=None, filter_func=None):
    """
    Wrap ``owner_dataset`` together with the outcome of a filtering pass.

    Args:
        owner_dataset (Dataset): the dataset that was filtered.
        indices (List[int]): indices of the samples that were filtered out;
            stored as-is (may be None).
        filter_name (str): human-readable name of the filter.
        filter_func (Callable): the filter itself, stored only for tracking.
    """
    super().__init__()
    self.owner_dataset, self.indices = owner_dataset, indices
    # Filter metadata is private bookkeeping; this constructor only stores it.
    self._filter_name, self._filter_func = filter_name, filter_func
示例6: __init__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, path, templates, img_transforms=None, dataset_root="", split="train",
             train=True, input_size=(500, 500), heatmap_size=(63, 63),
             pos_thresh=0.7, neg_thresh=0.3, pos_fraction=0.5, debug=False):
    """Load the annotations at ``path`` and configure target generation.

    Args:
        path: annotation source handed to ``self.load``.
        templates: canonical object templates obtained via clustering; the
            values are taken directly from Peiyun's repository (stored in
            "templates.json").
        img_transforms: optional callable, stored as ``self.transforms``.
        dataset_root (str): root directory, stored as a ``Path``.
        split (str): split name, stored and used in the log message.
        train (bool): accepted for interface compatibility; unused here.
        input_size, heatmap_size: sizes forwarded to ``DataProcessor``.
        pos_thresh, neg_thresh: thresholds forwarded to ``DataProcessor``.
        pos_fraction (float): stored as ``self.pos_fraction``.
        debug (bool): stored as ``self.debug``.
    """
    super().__init__()
    self.data = []
    self.split = split
    # self.load presumably populates self.data (its length is reported next).
    self.load(path)
    print("Dataset loaded")
    sample_count = len(self.data)
    print(f"{sample_count} samples in the {self.split} dataset")

    self.templates = templates
    self.transforms = img_transforms
    self.dataset_root = Path(dataset_root)
    self.input_size, self.heatmap_size = input_size, heatmap_size
    self.pos_thresh, self.neg_thresh = pos_thresh, neg_thresh
    self.pos_fraction = pos_fraction

    # Receptive field computed from MatConvNet values plus derived equations.
    self.rf = {'size': [859, 859], 'stride': [8, 8], 'offset': [-1, -1]}
    self.processor = DataProcessor(input_size, heatmap_size, pos_thresh, neg_thresh,
                                   templates, rf=self.rf)
    self.debug = debug
示例7: __len__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __len__(self):
    """Return the dataset size recorded in ``self.count``."""
    size = self.count
    return size
##
# Dataset loading (for training)
##
# Operator to load hdf5-file for training
示例8: collate_task
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def collate_task(self, task):
    """Collate every sample of ``task`` into a single batch.

    Args:
        task: either a ``TorchDataset``, whose samples are gathered in index
            order and passed to ``self.collate_fn``, or an ``OrderedDict`` of
            such tasks, which is collated recursively with key order preserved.

    Returns:
        The result of ``self.collate_fn`` for a dataset, or an ``OrderedDict``
        mapping each key to its collated sub-task.

    Raises:
        NotImplementedError: if ``task`` is neither of the supported types.
    """
    if isinstance(task, TorchDataset):
        # Index explicitly: a map-style Dataset is not guaranteed to be iterable.
        return self.collate_fn([task[idx] for idx in range(len(task))])
    if isinstance(task, OrderedDict):
        return OrderedDict(
            (key, self.collate_task(subtask)) for key, subtask in task.items()
        )
    raise NotImplementedError(
        f"Cannot collate a task of type {type(task).__name__!r}; "
        "expected a TorchDataset or an OrderedDict of tasks."
    )
示例9: main
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def main():
    """Print speaker statistics for the train-clean-100 split.

    Relies on names defined elsewhere in the script: ``root``, ``drop``,
    ``max_timestep``, ``max_label_len``, ``get_all_speakers`` and
    ``compute_speaker2idx``.
    """
    # Load the train-clean-100 set (presumably LibriSpeech; verify with caller)
    tables = pd.read_csv(os.path.join(root, 'train-clean-100.csv'))

    # Compute speaker dictionary
    print('[Dataset] - Computing speaker class...')
    file_paths = tables['file_path'].tolist()
    speakers = get_all_speakers(file_paths)
    speaker2idx = compute_speaker2idx(speakers)
    class_num = len(speaker2idx)
    print('[Dataset] - Possible speaker classes: ', class_num)

    # 90/10 train/test split; random_state is a fixed seed for reproducibility.
    train = tables.sample(frac=0.9, random_state=20190929)
    test = tables.drop(train.index)  # held-out rows, reserved for further analysis
    table = train.sort_values(by=['length'], ascending=False)

    # Crop seqs that are too long. This must run *before* X / X_lens are
    # extracted; previously the filters ran afterwards and had no effect.
    if drop and max_timestep > 0:
        table = table[table.length < max_timestep]
    if drop and max_label_len > 0:
        table = table[table.label.str.count('_') + 1 < max_label_len]
    X = table['file_path'].tolist()
    X_lens = table['length'].tolist()

    # Compute utterances per speaker; `speakers` is indexed by speaker id here
    # (presumably mapping speaker -> utterance count; confirm get_all_speakers).
    num_utt = [speakers[speaker] for speaker in speakers if speaker in speaker2idx]
    print('Average utterance per speaker: ', np.mean(num_utt))
    # TODO: further analysis (test, X and X_lens are not used yet)
开发者ID:andi611,项目名称:Self-Supervised-Speech-Pretraining-and-Representation-Learning,代码行数:37,代码来源:observe_speaker.py
示例10: __init__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, run_mam, file_path, sets, bucket_size, max_timestep=0, drop=False, mam_config=None):
    """Group the parent's utterance table into length-bucketed batches.

    Args:
        run_mam: stored as ``self.run_mam``.
        file_path, sets, max_timestep, drop: forwarded to the parent
            constructor, which presumably builds ``self.table`` (read below).
        bucket_size (int): number of utterances per bucket; also forwarded
            to the parent constructor.
        mam_config (dict, optional): if it contains ``'max_input_length'``,
            that value becomes ``self.sample_step``; otherwise (including the
            default ``None``) ``self.sample_step`` is 0.
    """
    super(AcousticDataset, self).__init__(file_path, sets, bucket_size, max_timestep, drop)
    self.run_mam = run_mam
    self.mam_config = mam_config
    # Guard the default mam_config=None: `'key' in None` would raise TypeError.
    if mam_config is not None and 'max_input_length' in mam_config:
        self.sample_step = mam_config['max_input_length']
    else:
        self.sample_step = 0
    if self.sample_step > 0:
        print('[Dataset] - Sampling random segments for training, sample length:', self.sample_step)

    X = self.table['file_path'].tolist()
    X_lens = self.table['length'].tolist()

    # Use bucketing to allow different batch size at run time
    self.X = []
    batch_x, batch_len = [], []
    for x, x_len in zip(X, X_lens):
        batch_x.append(x)
        batch_len.append(x_len)
        # Fill in batch_x until the bucket is full
        if len(batch_x) == bucket_size:
            # Halve the bucket if it holds a sequence longer than the threshold
            if bucket_size >= 2 and max(batch_len) > HALF_BATCHSIZE_TIME:
                half = bucket_size // 2
                self.X.append(batch_x[:half])
                self.X.append(batch_x[half:])
            else:
                self.X.append(batch_x)
            batch_x, batch_len = [], []
    # Gather the last (possibly partial) bucket; it is never halved
    if batch_x:
        self.X.append(batch_x)
开发者ID:andi611,项目名称:Self-Supervised-Speech-Pretraining-and-Representation-Learning,代码行数:33,代码来源:dataloader.py
示例11: __getitem__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __getitem__(self, index):
    """Load sample ``index`` for the configured split.

    Returns:
        * ``'train'``: ``(img, class_map, reg_map)``, where ``class_map`` and
          ``reg_map`` are converted from NumPy arrays to tensors.
        * ``'val'`` / ``'test'``: ``(img, img_path)``.

        If ``self.transforms`` is None, ``img`` is the RGB PIL image (the
        val/test branches previously raised ``UnboundLocalError`` here).

    Raises:
        ValueError: if ``self.split`` is not 'train', 'val' or 'test'
            (previously the method silently returned None).
    """
    if self.split not in ('train', 'val', 'test'):
        raise ValueError(f"Unknown split: {self.split!r}")

    datum = self.data[index]
    image_root = self.dataset_root / "WIDER_{0}".format(self.split)
    image_path = image_root / "images" / datum['img_path']
    # Context manager releases the file handle; convert() returns a new image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    if self.split == 'train':
        bboxes = datum['bboxes']
        if self.debug:
            if bboxes.shape[0] == 0:
                print(image_path)
            print("Dataset index: \t", index)
            print("image path:\t", image_path)
        img, class_map, reg_map, _ = self.process_inputs(image, bboxes)
        # convert everything to tensors
        if self.transforms is not None:
            # Presumably self.transforms includes ToTensor, which scales uint8
            # 0-255 to 0-1 and converts (HxWxC) to (CxHxW) — confirm.
            img = self.transforms(img)
        class_map = torch.from_numpy(class_map)
        reg_map = torch.from_numpy(reg_map)
        return img, class_map, reg_map

    if self.split == 'val':
        # NOTE Return only the image and the image path.
        # Use the eval_tools to get the final results.
        img = image
        if self.transforms is not None:
            # Only convert to tensor since we do normalization after rescaling
            img = transforms.functional.to_tensor(image)
        return img, datum['img_path']

    # split == 'test'
    img = self.transforms(image) if self.transforms is not None else image
    return img, datum['img_path']