当前位置: 首页>>代码示例>>Python>>正文


Python dataset.Dataset方法代码示例

本文整理汇总了Python中torch.utils.data.dataset.Dataset方法的典型用法代码示例（注意：Dataset 实际上是一个类，而非方法）。如果您正苦于以下问题：Python dataset.Dataset方法的具体用法？Python dataset.Dataset怎么用？Python dataset.Dataset使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data.dataset的用法示例。


在下文中一共展示了dataset.Dataset方法的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: one_batch_dataset

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def one_batch_dataset(dataset, batch_size):
    """Return a ``Dataset`` that holds one random batch drawn from ``dataset``.

    Up to ``batch_size`` distinct indices are drawn with ``torch.randperm``.
    The matching samples are fetched eagerly, so the returned wrapper always
    serves that same fixed set of samples.

    Args:
        dataset: an object supporting ``len()`` and integer indexing.
        batch_size (int): number of samples to keep. If it exceeds
            ``len(dataset)``, every sample is kept (in shuffled order).

    Returns:
        Dataset: an in-memory wrapper exposing ``__getitem__`` / ``__len__``.
    """
    print("==> Grabbing a single batch")

    # .tolist() yields plain Python ints, just like calling .item() per index.
    chosen = torch.randperm(len(dataset))[:batch_size].tolist()
    samples = [dataset[i] for i in chosen]

    class _OneBatchWrapper(Dataset):
        def __init__(self, batch):
            self.batch = batch

        def __getitem__(self, index):
            return self.batch[index]

        def __len__(self):
            return len(self.batch)

    return _OneBatchWrapper(samples)
开发者ID:allenai,项目名称:hidden-networks,代码行数:20,代码来源:utils.py

示例2: __init__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(
        self,
        video_list: List[str],
        frame_selector: Optional[FrameSelector] = None,
        transform: Optional[FrameTransform] = None,
    ):
        """
        Store the inputs that drive keyframe extraction.

        Args:
            video_list (List[str]): paths of the video files to read
            frame_selector (Callable: KeyFrameList -> KeyFrameList):
                chooses which keyframes to process; keyframes are identified
                by packet timestamps expressed in timebase counts. When None,
                every keyframe is kept (default: None)
            transform (Callable: torch.Tensor -> torch.Tensor):
                applied to a batch of RGB images (a tensor of size
                [B, H, W, 3]) and expected to return a tensor of the same
                size. When None, frames are left untransformed (default: None)
        """
        # Tuple assignment binds left to right, in the same order as before.
        self.video_list, self.frame_selector, self.transform = (
            video_list,
            frame_selector,
            transform,
        )
开发者ID:facebookresearch,项目名称:detectron2,代码行数:26,代码来源:video_keyframe_dataset.py

示例3: __init__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Load one split of a dataset from processed tensor files.

        Args:
            root (str): dataset root directory; ``~`` is expanded.
            train (bool): True loads ``self.training_file`` into
                ``train_data``/``train_labels``; False loads ``self.test_file``
                into ``test_data``/``test_labels``.
            transform: stored as ``self.transform``.
            target_transform: stored as ``self.target_transform``.
            download (bool): if True, ``self.download()`` is called first.

        Raises:
            RuntimeError: if ``self._check_exists()`` reports the data missing.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # True -> training split, False -> test split

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Resolve the split file once, then route the loaded pair to the
        # attributes belonging to that split.
        split_file = self.training_file if self.train else self.test_file
        data, labels = torch.load(
            os.path.join(self.root, self.processed_folder, split_file))
        if self.train:
            self.train_data, self.train_labels = data, labels
        else:
            self.test_data, self.test_labels = data, labels
开发者ID:mhw32,项目名称:multimodal-vae-public,代码行数:21,代码来源:datasets.py

示例4: __init__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, dataset_name, data_path, train=True):
        """Initialise empty per-sample file-name lists and fill them for the dataset.

        Args:
            dataset_name (str): which dataset layout to read; only
                ``'matterport'`` is supported.
            data_path: root directory of the data, stored as ``self.data_root``.
            train (bool): stored as ``self.train`` and forwarded to the loader
                as ``train=``.

        Raises:
            ValueError: if ``dataset_name`` is not supported. ValueError is a
                subclass of Exception, so existing ``except Exception``
                handlers still catch it.
        """
        self.dataset_name = dataset_name
        if self.dataset_name not in {'matterport'}:
            # A specific exception type instead of bare Exception lets callers
            # distinguish bad configuration from other failures.
            raise ValueError(f'Dataset name not found: {self.dataset_name}')
        self.data_root = data_path
        self.len = 0
        self.train = train
        # Parallel lists of per-sample names; presumably populated by the
        # dataset-specific loader below (TODO confirm in the loader).
        self.scene_name = []
        self.color_name = []
        self.depth_name = []
        self.normal_name = []
        self.render_name = []
        self.boundary_name = []
        self.depth_boundary_name = []

        if self.dataset_name == 'matterport':
            self._load_data_name_matterport(train=self.train)
开发者ID:patrickwu2,项目名称:Depth-Completion,代码行数:19,代码来源:data_loader.py

示例5: __init__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, owner_dataset, indices=None, filter_name=None, filter_func=None):
        """
        Create a filtered view over an existing dataset.

        Args:
            owner_dataset (Dataset): the dataset this view is derived from.
            indices (List[int]): the indices produced by the filter.
            filter_name (str): a human-readable label for the filter.
            filter_func (Callable): recorded for bookkeeping only; it is not
                invoked here.
        """
        super().__init__()
        self.owner_dataset, self.indices = owner_dataset, indices
        # Filter metadata is stored privately and only recorded here.
        self._filter_name, self._filter_func = filter_name, filter_func
开发者ID:vacancy,项目名称:NSCL-PyTorch-Release,代码行数:16,代码来源:filterable.py

示例6: __init__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, path, templates, img_transforms=None, dataset_root="", split="train",
                 train=True, input_size=(500, 500), heatmap_size=(63, 63),
                 pos_thresh=0.7, neg_thresh=0.3, pos_fraction=0.5, debug=False):
        """Load WIDER FACE annotations and configure target generation.

        Args:
            path: annotation source handed to ``self.load``.
            templates: canonical object templates (from clustering), stored and
                passed to ``DataProcessor``.
            img_transforms: optional image transform, stored as ``self.transforms``.
            dataset_root (str): root directory, stored as a ``pathlib.Path``.
            split (str): dataset split name, stored as ``self.split``.
            train (bool): accepted but not used by this constructor.
            input_size, heatmap_size, pos_thresh, neg_thresh: stored and
                forwarded to ``DataProcessor``.
            pos_fraction (float): stored as ``self.pos_fraction``.
            debug (bool): stored as ``self.debug``.
        """
        super().__init__()

        # self.load() runs only after self.split is set, in case it reads it.
        self.data = []
        self.split = split
        self.load(path)

        print("Dataset loaded")
        sample_count = len(self.data)
        print("{0} samples in the {1} dataset".format(sample_count, self.split))

        # Canonical object templates obtained via clustering; the values come
        # directly from Peiyun's repository ("templates.json").
        self.templates = templates
        self.transforms = img_transforms
        self.dataset_root = Path(dataset_root)
        self.input_size, self.heatmap_size = input_size, heatmap_size
        self.pos_thresh, self.neg_thresh = pos_thresh, neg_thresh
        self.pos_fraction = pos_fraction

        # Receptive field, computed from Matconvnet values plus derived equations.
        receptive_field = {
            'size': [859, 859],
            'stride': [8, 8],
            'offset': [-1, -1],
        }
        self.rf = receptive_field

        self.processor = DataProcessor(input_size, heatmap_size,
                                       pos_thresh, neg_thresh,
                                       templates, rf=receptive_field)
        self.debug = debug
开发者ID:varunagrawal,项目名称:tiny-faces-pytorch,代码行数:41,代码来源:wider_face.py

示例7: __len__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __len__(self):
        """Return the number of samples, as recorded in ``self.count``."""
        num_samples = self.count
        return num_samples


##
# Dataset loading (for training)
##

# Operator to load hdf5-file for training 
开发者ID:Deep-MI,项目名称:FastSurfer,代码行数:11,代码来源:load_neuroimaging_data.py

示例8: collate_task

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def collate_task(self, task):
        """Collate a task, recursing through nested ``OrderedDict`` containers.

        A ``TorchDataset`` is materialised by integer index (``0..len-1``) and
        handed to ``self.collate_fn`` as a list. An ``OrderedDict`` is rebuilt
        with each value collated recursively, preserving key order.

        Raises:
            NotImplementedError: for any other task type.
        """
        if isinstance(task, TorchDataset):
            # Index explicitly: a map-style dataset need not be iterable.
            samples = [task[i] for i in range(len(task))]
            return self.collate_fn(samples)
        if isinstance(task, OrderedDict):
            return OrderedDict(
                (name, self.collate_task(sub)) for name, sub in task.items())
        raise NotImplementedError()
开发者ID:tristandeleu,项目名称:pytorch-meta,代码行数:10,代码来源:dataloader.py

示例9: main

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def main():
    """Print speaker statistics for the ``train-clean-100`` split.

    Relies on names defined elsewhere in the module: ``root``, ``drop``,
    ``max_timestep``, ``max_label_len``, ``get_all_speakers`` and
    ``compute_speaker2idx``.
    """
    # Load the train-clean-100 set
    tables = pd.read_csv(os.path.join(root, 'train-clean-100' + '.csv'))

    # Compute speaker dictionary
    print('[Dataset] - Computing speaker class...')
    file_paths = tables['file_path'].tolist()
    speakers = get_all_speakers(file_paths)
    speaker2idx = compute_speaker2idx(speakers)
    class_num = len(speaker2idx)
    print('[Dataset] - Possible speaker classes: ', class_num)

    # 90/10 split with a fixed seed so the split is reproducible.
    train = tables.sample(frac=0.9, random_state=20190929)
    test = tables.drop(train.index)  # held-out rows, kept for the pending analysis
    table = train.sort_values(by=['length'], ascending=False)

    # Drop sequences that are too long BEFORE extracting the columns, so that
    # X / X_lens reflect the filtered table (previously they were read first
    # and silently ignored the filter).
    if drop and max_timestep > 0:
        table = table[table.length < max_timestep]
    if drop and max_label_len > 0:
        table = table[table.label.str.count('_') + 1 < max_label_len]

    X = table['file_path'].tolist()
    X_lens = table['length'].tolist()

    # Compute utterances per speaker. Assumes `speakers` maps each speaker to
    # its utterance count (TODO confirm against get_all_speakers).
    num_utt = [speakers[speaker] for speaker in speakers if speaker in speaker2idx]
    if num_utt:
        print('Average utterance per speaker: ', np.mean(num_utt))
    else:
        # np.mean([]) would print nan and emit a RuntimeWarning.
        print('Average utterance per speaker: ', 'n/a (no speakers found)')

    # TODO: further analysis (intended to use `test`, `X` and `X_lens`)
开发者ID:andi611,项目名称:Self-Supervised-Speech-Pretraining-and-Representation-Learning,代码行数:37,代码来源:observe_speaker.py

示例10: __init__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __init__(self, run_mam, file_path, sets, bucket_size, max_timestep=0, drop=False, mam_config=None):
        """Group utterance feature paths into buckets for acoustic pretraining.

        Args:
            run_mam: stored as ``self.run_mam``.
            file_path, sets, bucket_size, max_timestep, drop: forwarded to the
                parent constructor, which presumably provides ``self.table``
                (its ``file_path`` and ``length`` columns are read below).
            bucket_size (int): number of utterances per bucket.
            mam_config (mapping or None): optional settings. Only
                ``'max_input_length'`` is read here; it becomes
                ``self.sample_step``, or 0 when the key or the whole config is
                absent.
        """
        super(AcousticDataset, self).__init__(file_path, sets, bucket_size, max_timestep, drop)

        self.run_mam = run_mam
        self.mam_config = mam_config
        # mam_config defaults to None; treat that as an empty config instead of
        # raising TypeError on the membership test.
        config = mam_config if mam_config is not None else {}
        self.sample_step = config['max_input_length'] if 'max_input_length' in config else 0
        if self.sample_step > 0:
            print('[Dataset] - Sampling random segments for training, sample length:', self.sample_step)
        X = self.table['file_path'].tolist()
        X_lens = self.table['length'].tolist()

        # Use bucketing to allow different batch size at run time
        self.X = []
        batch_x, batch_len = [], []

        for x, x_len in zip(X, X_lens):
            batch_x.append(x)
            batch_len.append(x_len)

            # Fill in batch_x until batch is full
            if len(batch_x) == bucket_size:
                # Halve the bucket when its longest sequence exceeds
                # HALF_BATCHSIZE_TIME, so long utterances get smaller batches.
                if (bucket_size >= 2) and (max(batch_len) > HALF_BATCHSIZE_TIME):
                    self.X.append(batch_x[:bucket_size//2])
                    self.X.append(batch_x[bucket_size//2:])
                else:
                    self.X.append(batch_x)
                batch_x, batch_len = [], []

        # Gather the last (possibly partial) bucket; it is not subject to halving.
        if len(batch_x) > 0:
            self.X.append(batch_x)
开发者ID:andi611,项目名称:Self-Supervised-Speech-Pretraining-and-Representation-Learning,代码行数:33,代码来源:dataloader.py

示例11: __getitem__

# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import Dataset [as 别名]
def __getitem__(self, index):
        """Return the sample at ``index`` in the form required by ``self.split``.

        Returns:
            * ``'train'``: ``(img, class_map, reg_map)``; the maps are converted
              with ``torch.from_numpy``.
            * ``'val'``: ``(img, img_path)``; ``img`` is a tensor when
              ``self.transforms`` is set, otherwise the RGB PIL image.
            * ``'test'``: ``(img, img_path)``; ``img`` is
              ``self.transforms(image)`` when set, otherwise the RGB PIL image.

        Raises:
            ValueError: if ``self.split`` is not ``'train'``, ``'val'`` or ``'test'``.
        """
        datum = self.data[index]

        image_root = self.dataset_root / "WIDER_{0}".format(self.split)
        image_path = image_root / "images" / datum['img_path']
        # Close the underlying file deterministically; convert() returns a
        # converted copy, so `image` stays usable after the with-block.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.split == 'train':
            bboxes = datum['bboxes']

            if self.debug:
                if bboxes.shape[0] == 0:
                    print(image_path)
                print("Dataset index: \t", index)
                print("image path:\t", image_path)

            # The processed bboxes returned here are not used further.
            img, class_map, reg_map, bboxes = self.process_inputs(image,
                                                                  bboxes)

            # convert everything to tensors
            if self.transforms is not None:
                # if img is a byte or uint8 array, it will convert from 0-255 to 0-1
                # this converts from (HxWxC) to (CxHxW) as well
                img = self.transforms(img)

            class_map = torch.from_numpy(class_map)
            reg_map = torch.from_numpy(reg_map)

            return img, class_map, reg_map

        elif self.split == 'val':
            # NOTE Return only the image and the image path.
            # Use the eval_tools to get the final results.
            # Default to the raw image: previously `img` was unbound (and
            # UnboundLocalError was raised) when no transforms were configured.
            img = image
            if self.transforms is not None:
                # Only convert to tensor since we do normalization after rescaling
                img = transforms.functional.to_tensor(image)

            return img, datum['img_path']

        elif self.split == 'test':
            filename = datum['img_path']

            img = image  # same unbound-variable fix as the 'val' branch
            if self.transforms is not None:
                img = self.transforms(image)

            return img, filename

        # Previously fell through and returned None, which breaks collation later.
        raise ValueError("Unknown split: {0}".format(self.split))
开发者ID:varunagrawal,项目名称:tiny-faces-pytorch,代码行数:48,代码来源:wider_face.py


注:本文中的torch.utils.data.dataset.Dataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。