This article collects typical usage examples of torch.utils.data.Dataset in Python. If you are unsure what data.Dataset does, how to use it, or what working code that uses it looks like, the curated examples below should help. You can also explore the containing module, torch.utils.data, for related APIs.
The following presents 15 code examples of data.Dataset, sorted by popularity by default.
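Before the collected examples, here is a minimal, self-contained sketch (not taken from any of the projects below; all names and shapes are illustrative) of how torch.utils.data.Dataset is typically subclassed and wrapped in a DataLoader:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """A toy map-style dataset: a Dataset subclass only needs __len__ and __getitem__."""
    def __init__(self, num_samples=100, num_features=8):
        self.x = torch.randn(num_samples, num_features)   # illustrative random features
        self.y = torch.randint(0, 2, (num_samples,))       # illustrative binary labels

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

loader = DataLoader(ToyDataset(), batch_size=16, shuffle=True)
for batch_x, batch_y in loader:
    pass  # each batch_x has shape (16, 8), except possibly the last batch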
Example 1: get_test_loader
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
                    workers, opt, cap_suffix='caps'):
    dpath = os.path.join(opt.data_path, data_name)
    if opt.data_name.endswith('_precomp'):
        if not opt.use_external_captions:
            test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
                                             batch_size, False, workers, cap_suffix)
        else:
            test_loader = get_precomp_train_caption_loader(dpath, split_name, vocab, opt,
                                                           batch_size, False, workers, cap_suffix)
    else:
        # Build Dataset Loader
        roots, ids = get_paths(dpath, data_name, opt.use_restval)
        transform = get_transform(data_name, split_name, opt)
        test_loader = get_loader_single(opt.data_name, split_name,
                                        roots[split_name]['img'],
                                        roots[split_name]['cap'],
                                        vocab, transform, ids=ids[split_name],
                                        batch_size=batch_size, shuffle=False,
                                        num_workers=workers,
                                        collate_fn=collate_fn)
    return test_loader
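A hypothetical call site for the function above; only the `opt` attributes that get_test_loader actually reads are set, and `vocab` is assumed to have been built or deserialized elsewhere (all names illustrative):

from argparse import Namespace

# Hypothetical options object and vocabulary; values are illustrative only.
opt = Namespace(data_path="/data", data_name="coco_precomp", use_external_captions=False)
test_loader = get_test_loader("test", opt.data_name, vocab, crop_size=224,
                              batch_size=128, workers=2, opt=opt)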
Example 2: get_seq
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def get_seq(pairs, lang, batch_size, type, max_len):
    x_seq = []
    y_seq = []
    ptr_seq = []
    for pair in pairs:
        x_seq.append(pair[0])
        y_seq.append(pair[1])
        ptr_seq.append(pair[2])
        if(type):
            lang.index_words(pair[0])
            lang.index_words(pair[1])

    dataset = Dataset(x_seq, y_seq, ptr_seq, lang.word2index, lang.word2index, max_len)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=type,
                                              collate_fn=collate_fn)
    return data_loader
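A hypothetical call site for the function above (the names `train_pairs`, `dev_pairs`, and `lang` are assumptions, not from the original project): passing type=True both shuffles the batches and indexes new words into the vocabulary, while type=False does neither.

# Hypothetical usage; train_pairs/dev_pairs and the lang vocabulary object are assumed to exist.
train_loader = get_seq(train_pairs, lang, batch_size=32, type=True, max_len=100)
dev_loader = get_seq(dev_pairs, lang, batch_size=32, type=False, max_len=100)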
Example 3: get_seq
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def get_seq(pairs, lang, batch_size, type, max_len):
    x_seq = []
    y_seq = []
    ptr_seq = []
    gate_seq = []
    for pair in pairs:
        x_seq.append(pair[0])
        y_seq.append(pair[1])
        ptr_seq.append(pair[2])
        gate_seq.append(pair[3])
        if(type):
            lang.index_words(pair[0])
            lang.index_words(pair[1])

    dataset = Dataset(x_seq, y_seq, ptr_seq, gate_seq, lang.word2index, lang.word2index, max_len)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=type,
                                              collate_fn=collate_fn)
    return data_loader
Example 4: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def __init__(self, data_path=None, split_set='train', num_points=20000,
             use_color=False, use_height=False, use_v1=False,
             augment=False, scan_idx_list=None):
    assert(num_points <= 50000)
    self.use_v1 = use_v1
    if use_v1:
        self.data_path = os.path.join(data_path, 'sunrgbd_pc_bbox_votes_50k_v1_' + split_set)
    else:
        # Bug fix: the original constructed AssertionError(...) without raising it,
        # so unsupported v2 data was silently accepted; raise the error instead.
        raise AssertionError("v2 data is not prepared")

    self.raw_data_path = os.path.join(ROOT_DIR, 'sunrgbd/sunrgbd_trainval')
    self.scan_names = sorted(list(set([os.path.basename(x)[0:6]
                                       for x in os.listdir(self.data_path)])))
    if scan_idx_list is not None:
        self.scan_names = [self.scan_names[i] for i in scan_idx_list]
    self.num_points = num_points
    self.augment = augment
    self.use_color = use_color
    self.use_height = use_height
Example 5: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def __init__(self, num_epochs=None, *args, **kwargs):
    """Constructor.

    Args:
        dataset: A `Dataset` object to be loaded.
        batch_size: int, the size of each batch.
        shuffle: bool, whether to shuffle the dataset after each epoch.
        drop_last: bool, whether to drop the last batch if its size is less than
            `batch_size`.
        num_epochs: int or None, number of epochs to iterate over the dataset.
            If None, defaults to infinity.
    """
    super().__init__(*args, **kwargs)
    self.finite_iterable = super().__iter__()
    self.counter = 0
    self.num_epochs = float('inf') if num_epochs is None else num_epochs
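The excerpt above only shows the constructor of an "infinite" DataLoader wrapper. Below is a hedged sketch of what the full class might look like; the class name InfiniteLoader and the iteration logic are assumptions, not part of the original excerpt. The idea is to re-create the underlying finite iterator whenever an epoch ends, until num_epochs is exhausted.

from torch.utils.data import DataLoader

class InfiniteLoader(DataLoader):
    """Hedged sketch of a DataLoader wrapper matching the constructor shown above."""
    def __init__(self, num_epochs=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.finite_iterable = super().__iter__()
        self.counter = 0
        self.num_epochs = float('inf') if num_epochs is None else num_epochs

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.finite_iterable)
        except StopIteration:
            self.counter += 1
            if self.counter >= self.num_epochs:
                raise                                      # exhausted the requested epochs
            self.finite_iterable = super().__iter__()      # start a fresh pass over the dataset
            return next(self.finite_iterable)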
Example 6: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def __init__(self, root_dir, split_name, transform):
    super(ISSDataset, self).__init__()
    self.root_dir = root_dir
    self.split_name = split_name
    self.transform = transform

    # Folders
    self._img_dir = path.join(root_dir, ISSDataset._IMG_DIR)
    self._msk_dir = path.join(root_dir, ISSDataset._MSK_DIR)
    self._lst_dir = path.join(root_dir, ISSDataset._LST_DIR)
    for d in self._img_dir, self._msk_dir, self._lst_dir:
        if not path.isdir(d):
            raise IOError("Dataset sub-folder {} does not exist".format(d))

    # Load meta-data and split
    self._meta, self._images = self._load_split()
Example 7: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, seed=0):
    super(MNIST, self).__init__(root)
    self.seed = seed
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set

    if download:
        self.download()

    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')

    if self.train:
        data_file = self.training_file
    else:
        data_file = self.test_file
    self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
Example 8: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
    self.root = os.path.expanduser(root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set

    if download:
        self.download()

    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')

    if self.train:
        self.train_data, self.train_labels = torch.load(
            os.path.join(self.root, self.processed_folder, self.training_file))
    else:
        self.test_data, self.test_labels = torch.load(
            os.path.join(self.root, self.processed_folder, self.test_file))
Example 9: __getitem__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def __getitem__(self, idx):
    record = OrderedDict()
    for feat in chain(
            self.features.number_features,
            self.features.category_features):
        record[feat.name] = self.X_map[feat.name][idx]

    for feat in self.features.sequence_features:
        seq = self.X_map[feat.name][idx]
        record[feat.name] = Dataset.__pad_sequence(feat, seq)
        record[f"__{feat.name}_length"] = np.int64(seq.shape[0])

    if self.y is not None:
        record['label'] = self.y[idx]
    return record
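The `__pad_sequence` helper is not part of the excerpt. A plausible, hedged sketch of what such a helper might do is shown below; the `max_len` attribute on the feature object and the zero padding value are assumptions, not taken from the original project.

import numpy as np

def _pad_sequence(feat, seq, pad_value=0):
    """Hedged sketch only: pad or truncate a 1-D sequence to feat.max_len."""
    if seq.shape[0] >= feat.max_len:
        return seq[:feat.max_len]                                  # truncate long sequences
    return np.pad(seq, (0, feat.max_len - seq.shape[0]),
                  mode='constant', constant_values=pad_value)      # right-pad short sequences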
Example 10: transform
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def transform(self, fn, lazy=True):
    """Returns a new dataset with each sample transformed by the
    transformer function `fn`.

    Parameters
    ----------
    fn : callable
        A transformer function that takes a sample as input and
        returns the transformed sample.
    lazy : bool, default True
        If False, transforms all samples at once. Otherwise,
        transforms each sample on demand. Note that if `fn`
        is stochastic, you must set lazy to True or you will
        get the same result on all epochs.

    Returns
    -------
    Dataset
        The transformed dataset.
    """
    trans = _LazyTransformDataset(self, fn)
    if lazy:
        return trans
    return SimpleDataset([i for i in trans])
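A hedged usage sketch of the method above, assuming `dataset` is an instance of the dataset class that defines `transform` and that its samples are numeric; the noise-injection function is illustrative only:

import random

def add_noise(sample):
    # Stochastic transform: lazy=True (the default) is needed so a fresh value
    # is drawn each time a sample is fetched, rather than being precomputed once.
    return sample + random.gauss(0.0, 0.01)

lazy_ds = dataset.transform(add_noise)                       # transformed on demand
eager_ds = dataset.transform(lambda s: s * 2, lazy=False)    # deterministic, safe to precompute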
Example 11: _dataset_from_chunk
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def _dataset_from_chunk(cls, chunk, processor):
    """
    Creating a dataset for a chunk (= subset) of dicts. In multiprocessing:
      * we read in all dicts from a file
      * split all dicts into chunks
      * feed *one chunk* to *one process*
        => the *one chunk* gets converted to *one dataset* (that's what we do here)
      * all datasets get collected and concatenated

    :param chunk: Instead of only having a list of dicts here we also supply an index (ascending int) for each.
                  => [(0, dict), (1, dict) ...]
    :type chunk: list of tuples
    :param processor: FARM Processor (e.g. TextClassificationProcessor)
    :return: PyTorch Dataset
    """
    dicts = [d[1] for d in chunk]
    indices = [x[0] for x in chunk]
    dataset = processor.dataset_from_dicts(dicts=dicts, indices=indices)
    return dataset
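A hedged sketch of the driver flow the docstring describes: index the dicts, split them into chunks, convert each chunk in a worker process, then concatenate the per-chunk datasets. The owning class is written here as DataSilo, and the chunk size, process count, and use of ConcatDataset are assumptions, not the project's actual driver code.

from multiprocessing import Pool
from torch.utils.data import ConcatDataset

def build_dataset_in_parallel(dicts, processor, chunk_size=1000, num_processes=4):
    """Hedged sketch only; names and parallelization details are assumptions."""
    indexed = list(enumerate(dicts))                        # [(0, dict), (1, dict), ...]
    chunks = [indexed[i:i + chunk_size] for i in range(0, len(indexed), chunk_size)]
    with Pool(processes=num_processes) as pool:
        datasets = pool.starmap(DataSilo._dataset_from_chunk,
                                [(chunk, processor) for chunk in chunks])
    return ConcatDataset(datasets)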
Example 12: covert_dataset_to_dataloader
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def covert_dataset_to_dataloader(dataset, sampler, batch_size):
    """
    Wraps a PyTorch Dataset with a DataLoader.

    :param dataset: Dataset to be wrapped.
    :type dataset: Dataset
    :param sampler: PyTorch sampler used to pick samples in a batch.
    :type sampler: Sampler
    :param batch_size: Number of samples in the batch.
    :return: A DataLoader that wraps the input Dataset.
    """
    sampler_initialized = sampler(dataset)
    data_loader = DataLoader(
        dataset, sampler=sampler_initialized, batch_size=batch_size
    )
    return data_loader
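A brief usage sketch, assuming `train_set` and `dev_set` are ordinary PyTorch Datasets built elsewhere; note that the function expects the sampler class, which it instantiates with the dataset itself:

from torch.utils.data import RandomSampler, SequentialSampler

# Hypothetical datasets; the function instantiates the sampler internally.
train_loader = covert_dataset_to_dataloader(train_set, RandomSampler, batch_size=32)
dev_loader = covert_dataset_to_dataloader(dev_set, SequentialSampler, batch_size=32)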
Example 13: resume_training
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def resume_training(self, train_data, model_path, valid_data=None):
    """Resumes training of a classifier by reloading the appropriate state_dicts for each model.

    Args:
        train_data: a tuple of Tensors (X, Y), a Dataset, or a DataLoader of
            X (data) and Y (labels) for the train split
        model_path: the path to the saved checkpoint for resuming training
        valid_data: a tuple of Tensors (X, Y), a Dataset, or a DataLoader of
            X (data) and Y (labels) for the dev split
    """
    restore_state = self.checkpointer.restore(model_path)
    loss_fn = self._get_loss_fn()
    self.train()
    self._train_model(
        train_data=train_data,
        loss_fn=loss_fn,
        valid_data=valid_data,
        restore_state=restore_state,
    )
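A brief hedged call sketch; `clf`, the tensors, and the checkpoint path are illustrative names, not from the original project:

# Hypothetical usage of the method above on an already-constructed classifier.
clf.resume_training((X_train, Y_train), "checkpoints/best_model.pth",
                    valid_data=(X_dev, Y_dev))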
Example 14: _create_data_loader
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def _create_data_loader(self, data, **kwargs):
    """Converts input data into a DataLoader"""
    if data is None:
        return None

    # Set DataLoader config
    # NOTE: Not applicable if data is already a DataLoader
    config = {
        **self.config["train_config"]["data_loader_config"],
        **kwargs,
        "pin_memory": self.config["device"] != "cpu",
    }
    # Return data as DataLoader
    if isinstance(data, DataLoader):
        return data
    elif isinstance(data, Dataset):
        return DataLoader(data, **config)
    elif isinstance(data, (tuple, list)):
        return DataLoader(self._create_dataset(*data), **config)
    else:
        raise ValueError("Input data type not recognized.")
Example 15: __init__
# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import Dataset [as alias]
def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False):
    """
    Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
    a comma. Each new line is a different sample. Example below:

    /path/to/audio.wav,/path/to/audio.txt
    ...

    :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
    :param manifest_filepath: Path to manifest csv as described above
    :param labels: String containing all the possible characters to map to
    :param normalize: Apply standard mean and deviation normalization to audio tensor
    :param augment: (default False) Apply random tempo and gain perturbations
    """
    with open(manifest_filepath) as f:
        ids = f.readlines()
    ids = [x.strip().split(',') for x in ids]
    self.ids = ids
    self.size = len(ids)
    self.labels_map = dict([(labels[i], i) for i in range(len(labels))])
    super(SpectrogramDataset, self).__init__(audio_conf, normalize, augment)
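A hedged instantiation sketch for the dataset above. The class name SpectrogramDataset comes from the super() call, but the audio_conf keys, label string, and file paths are assumptions following common speech-recognition conventions rather than values from the original project:

# Hypothetical manifest file contents (one "<wav path>,<transcript path>" pair per line):
#   /data/train/sample_0001.wav,/data/train/sample_0001.txt
#   /data/train/sample_0002.wav,/data/train/sample_0002.txt

audio_conf = {"sample_rate": 16000, "window_size": 0.02,
              "window_stride": 0.01, "window": "hamming"}   # keys assumed, per the docstring
labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "                     # illustrative character set

train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                    manifest_filepath="/path/to/train_manifest.csv",
                                    labels=labels, normalize=True, augment=True)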