当前位置: 首页>>代码示例>>Python>>正文


Python data.DataLoader类代码示例

本文整理汇总了Python中mxnet.gluon.data.DataLoader类的典型用法代码示例。如果您正苦于以下问题：Python data.DataLoader类的具体用法？Python data.DataLoader怎么用？Python data.DataLoader使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块mxnet.gluon.data的用法示例。


在下文中一共展示了data.DataLoader类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_imagenet_iterator

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
    """Build the training and validation iterators for an ImageNet folder tree.

    Args:
        root: Directory holding the ``train`` and ``val`` sub-folders.
        batch_size: Batch size handed to both ``DataLoader`` instances.
        num_workers: Worker count handed to both ``DataLoader`` instances.
        data_shape: Forwarded to ``get_imagenet_transforms``.
        dtype: Forwarded to ``get_imagenet_transforms`` and ``DataLoaderIter``.

    Returns:
        A ``(train_iter, val_iter)`` pair of ``DataLoaderIter`` objects.

    Raises:
        ValueError: If ``<root>/val/n01440764`` (after ``~`` expansion) is not
            a directory.
    """
    train_folder = os.path.join(root, 'train')
    train_tf, val_tf = get_imagenet_transforms(data_shape, dtype)

    logging.info("Loading image folder %s, this may take a bit long...", train_folder)
    # Training batches are shuffled and an incomplete final batch is dropped.
    train_loader = DataLoader(
        ImageFolderDataset(train_folder, transform=train_tf),
        batch_size,
        shuffle=True,
        last_batch='discard',
        num_workers=num_workers)

    val_folder = os.path.join(root, 'val')
    # A single known class folder is probed to detect the per-category layout.
    probe = os.path.expanduser(os.path.join(val_folder, 'n01440764'))
    if not os.path.isdir(probe):
        raise ValueError(
            'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1')

    logging.info("Loading image folder %s, this may take a bit long...", val_folder)
    # Validation keeps the incomplete final batch and is not shuffled.
    val_loader = DataLoader(
        ImageFolderDataset(val_folder, transform=val_tf),
        batch_size,
        last_batch='keep',
        num_workers=num_workers)

    train_iter = DataLoaderIter(train_loader, dtype)
    val_iter = DataLoaderIter(val_loader, dtype)
    return train_iter, val_iter
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:18,代码来源:data.py

示例2: get_caltech101_iterator

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def get_caltech101_iterator(batch_size, num_workers, dtype):
    """Build the training and test iterators for the Caltech-101 image folders.

    Args:
        batch_size: Batch size handed to both ``DataLoader`` instances.
        num_workers: Worker count handed to both ``DataLoader`` instances.
        dtype: Data type forwarded to ``DataLoaderIter`` for both iterators.

    Returns:
        A ``(train_iter, test_iter)`` pair of ``DataLoaderIter`` objects.
    """
    def transform(image, label):
        """Resize, center-crop to 224x224 and move the channel axis first."""
        # resize the shorter edge to 224, the longer edge will be greater or equal to 224
        resized = mx.image.resize_short(image, 224)
        # center and crop an area of size (224,224); the crop rectangle is not needed
        cropped, _ = mx.image.center_crop(resized, (224, 224))
        # transpose the channels to be (3,224,224)
        transposed = mx.nd.transpose(cropped, (2, 0, 1))
        return transposed, label

    training_path, testing_path = get_caltech101_data()
    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)

    # Only the training loader shuffles its samples.
    train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
    test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)

    # Pass dtype through so the data type requested by the caller is honored.
    return DataLoaderIter(train_data, dtype), DataLoaderIter(test_data, dtype)
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:19,代码来源:data.py

示例3: create_loader

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def create_loader(self):
        """
        Build the source/target train and test loaders and store them on self.

        The dataset family is picked by substring match on ``self.args.cfg``
        ('digits', 'office' or 'visda', tested in that order). The resulting
        loaders are assigned to ``train_src_loader``, ``train_tgt_loader``,
        ``test_src_loader`` and ``test_tgt_loader``.

        :raises NotImplementedError: if ``self.args.cfg`` matches none of them
        """
        n_workers = cpu_count()
        train_tforms, eval_tforms = self.create_transformer()

        if 'digits' in self.args.cfg:
            build_sets = self.create_digits_datasets
        elif 'office' in self.args.cfg:
            build_sets = self.create_office_datasets
        elif 'visda' in self.args.cfg:
            build_sets = self.create_visda_datasets
        else:
            raise NotImplementedError
        src_train, tgt_train, src_test, tgt_test = build_sets(train_tforms, eval_tforms)

        # (attribute name, dataset, shuffle flag): only training loaders shuffle.
        loader_plan = (
            ('train_src_loader', src_train, True),
            ('train_tgt_loader', tgt_train, True),
            ('test_src_loader', src_test, False),
            ('test_tgt_loader', tgt_test, False),
        )
        for attr_name, dataset, do_shuffle in loader_plan:
            setattr(self, attr_name,
                    DataLoader(dataset, self.args.bs, shuffle=do_shuffle, num_workers=n_workers))
开发者ID:aws-samples,项目名称:d-SNE,代码行数:23,代码来源:training_sda.py

示例4: create_loader

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def create_loader(self):
        """
        Build the training loader and both test loaders and store them on self.

        The dataset family is picked by substring match on ``self.args.cfg``
        ('digits' is tested before 'visda'). The resulting loaders are assigned
        to ``train_slu_loader``, ``test_src_loader`` and ``test_tgt_loader``.

        :raises NotImplementedError: if ``self.args.cfg`` matches neither
        """
        n_workers = cpu_count()
        train_tforms, eval_tforms = self.create_transformer()

        if 'digits' in self.args.cfg:
            build_sets = self.create_digits_datasets
        elif 'visda' in self.args.cfg:
            build_sets = self.create_visda_datasets
        else:
            raise NotImplementedError
        train_set, src_test, tgt_test = build_sets(train_tforms, eval_tforms)

        # (attribute name, dataset, shuffle flag): only the training loader shuffles.
        loader_plan = (
            ('train_slu_loader', train_set, True),
            ('test_src_loader', src_test, False),
            ('test_tgt_loader', tgt_test, False),
        )
        for attr_name, dataset, do_shuffle in loader_plan:
            setattr(self, attr_name,
                    DataLoader(dataset, self.args.bs, shuffle=do_shuffle, num_workers=n_workers))
开发者ID:aws-samples,项目名称:d-SNE,代码行数:20,代码来源:training_ssda.py

示例5: gluon_random_data_run

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def gluon_random_data_run():
    """Train a small dense network under mlflow autologging and return its run.

    Returns:
        The result of ``MlflowClient().get_run`` for the run opened here.
    """
    mlflow.gluon.autolog()

    with mlflow.start_run() as active:
        train_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")
        val_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

        # Two 64-unit ReLU layers followed by a 10-unit output layer.
        net = HybridSequential()
        for layer in (Dense(64, activation="relu"),
                      Dense(64, activation="relu"),
                      Dense(10)):
            net.add(layer)
        net.initialize()
        net.hybridize()

        adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
        optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
        fitter = estimator.Estimator(net=net, loss=SoftmaxCrossEntropyLoss(),
                                     metrics=Accuracy(), trainer=optimizer)

        # Warnings raised while fitting are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fitter.fit(train_loader, epochs=3, val_data=val_loader)

    return mlflow.tracking.MlflowClient().get_run(active.info.run_id)
开发者ID:mlflow,项目名称:mlflow,代码行数:25,代码来源:test_gluon_autolog.py

示例6: test_autolog_ends_auto_created_run

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def test_autolog_ends_auto_created_run():
    """Fit without an explicit run and check that no mlflow run is left active."""
    mlflow.gluon.autolog()

    train_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    # Two 64-unit ReLU layers followed by a 10-unit output layer.
    net = HybridSequential()
    for layer in (Dense(64, activation="relu"),
                  Dense(64, activation="relu"),
                  Dense(10)):
        net.add(layer)
    net.initialize()
    net.hybridize()

    adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
    optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
    fitter = estimator.Estimator(net=net, loss=SoftmaxCrossEntropyLoss(),
                                 metrics=Accuracy(), trainer=optimizer)

    # Warnings raised while fitting are suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fitter.fit(train_loader, epochs=3)

    assert mlflow.active_run() is None
开发者ID:mlflow,项目名称:mlflow,代码行数:24,代码来源:test_gluon_autolog.py

示例7: test_autolog_persists_manually_created_run

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def test_autolog_persists_manually_created_run():
    """Fit inside an explicit run and check that the same run is still active."""
    mlflow.gluon.autolog()

    train_loader = DataLoader(LogsDataset(), batch_size=128, last_batch="discard")

    with mlflow.start_run() as manual_run:
        # Two 64-unit ReLU layers followed by a 10-unit output layer.
        net = HybridSequential()
        for layer in (Dense(64, activation="relu"),
                      Dense(64, activation="relu"),
                      Dense(10)):
            net.add(layer)
        net.initialize()
        net.hybridize()

        adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
        optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
        fitter = estimator.Estimator(net=net, loss=SoftmaxCrossEntropyLoss(),
                                     metrics=Accuracy(), trainer=optimizer)

        # Warnings raised while fitting are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fitter.fit(train_loader, epochs=3)

        # Checked while still inside the manually opened run.
        expected_id = manual_run.info.run_id
        assert mlflow.active_run().info.run_id == expected_id
开发者ID:mlflow,项目名称:mlflow,代码行数:25,代码来源:test_gluon_autolog.py

示例8: gluon_model

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def gluon_model(model_data):
    """Train a three-layer dense network on ``model_data`` and return it.

    Args:
        model_data: A 3-item sequence; the first two items are the training
            samples and their labels, the third is ignored.

    Returns:
        The fitted ``HybridSequential`` model.
    """
    features, labels, _ = model_data
    # Pair every sample with its label so the loader yields (data, label) batches.
    samples = list(zip(features, labels))
    loader = DataLoader(samples, batch_size=128, last_batch="discard")

    net = HybridSequential()
    for layer in (Dense(128, activation="relu"),
                  Dense(64, activation="relu"),
                  Dense(10)):
        net.add(layer)
    net.initialize()
    net.hybridize()

    adam_settings = {"learning_rate": .001, "epsilon": 1e-07}
    optimizer = Trainer(net.collect_params(), "adam", optimizer_params=adam_settings)
    fitter = estimator.Estimator(net=net, loss=SoftmaxCrossEntropyLoss(),
                                 metrics=Accuracy(), trainer=optimizer)

    # Warnings raised while fitting are suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fitter.fit(loader, epochs=3)
    return net
开发者ID:mlflow,项目名称:mlflow,代码行数:20,代码来源:test_gluon_model_export.py

示例9: load_data_fashion_mnist

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
        '~', '.mxnet', 'datasets', 'fashion-mnist')):
    """Create train and test DataLoaders over ``gdata.vision.FashionMNIST``.

    Args:
        batch_size: Batch size for both loaders.
        resize: If truthy, passed to a ``Resize`` transform applied before
            ``ToTensor``.
        root: Dataset directory; a leading ``~`` is expanded.

    Returns:
        A ``(train_iter, test_iter)`` pair; only the training loader shuffles.
    """
    data_dir = os.path.expanduser(root)

    tfs = gdata.vision.transforms
    steps = [tfs.Resize(resize)] if resize else []
    steps.append(tfs.ToTensor())
    pipeline = tfs.Compose(steps)

    train_set = gdata.vision.FashionMNIST(root=data_dir, train=True)
    test_set = gdata.vision.FashionMNIST(root=data_dir, train=False)

    # Worker count is 0 when sys.platform starts with 'win32', otherwise 4.
    if sys.platform.startswith('win32'):
        workers = 0
    else:
        workers = 4

    def make_loader(split, shuffle):
        """Wrap one dataset split, transforming only the first element of each sample."""
        return gdata.DataLoader(split.transform_first(pipeline),
                                batch_size, shuffle=shuffle,
                                num_workers=workers)

    train_iter = make_loader(train_set, True)
    test_iter = make_loader(test_set, False)
    return train_iter, test_iter
开发者ID:d2l-ai,项目名称:d2l-zh,代码行数:23,代码来源:utils.py

示例10: test_array_dataset

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def test_array_dataset():
    """Check that DataLoader batches over ArrayDataset match slices of the arrays."""
    batch = 2
    X = np.random.uniform(size=(10, 20))
    Y = np.random.uniform(size=(10,))

    # Paired arrays: each batch must equal the corresponding slice of X and Y.
    paired_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, Y), batch)
    for step, (x_batch, y_batch) in enumerate(paired_loader):
        lo, hi = step * batch, (step + 1) * batch
        assert mx.test_utils.almost_equal(x_batch.asnumpy(), X[lo:hi])
        assert mx.test_utils.almost_equal(y_batch.asnumpy(), Y[lo:hi])

    # Single array: batches are yielded directly rather than as tuples.
    single_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X), batch)
    for step, x_batch in enumerate(single_loader):
        lo, hi = step * batch, (step + 1) * batch
        assert mx.test_utils.almost_equal(x_batch.asnumpy(), X[lo:hi])
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:16,代码来源:test_gluon_data.py

示例11: test_recordimage_dataset

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def test_recordimage_dataset():
    """Check batches from an ImageRecordDataset wrapped in an identity transform."""
    def identity(image, label):
        """Return the sample unchanged."""
        return (image, label)

    record_path = prepare_record()
    records = gluon.data.vision.ImageRecordDataset(record_path).transform(identity)
    loader = gluon.data.DataLoader(records, 1)

    for index, (image, label) in enumerate(loader):
        # Batch size is 1 and the last axis holds 3 channels.
        assert image.shape[0] == 1 and image.shape[3] == 3
        # The label of each record is expected to equal its position.
        assert label.asscalar() == index
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:11,代码来源:test_gluon_data.py

示例12: test_recordimage_dataset_with_data_loader_multiworker

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def test_recordimage_dataset_with_data_loader_multiworker():
    """Check ImageRecordDataset batches produced by a 5-worker DataLoader.

    Covers the plain dataset, the dataset with ``_dataset_transform_fn``
    applied, and the transformed dataset under a lowered recursion limit.
    The recursion limit is always restored, even if a check fails.
    """
    import sys

    def check_batches(loader):
        """Assert batch size 1, 3 channels on the last axis, label == position."""
        for i, (x, y) in enumerate(loader):
            assert x.shape[0] == 1 and x.shape[3] == 3
            assert y.asscalar() == i

    recfile = prepare_record()
    dataset = gluon.data.vision.ImageRecordDataset(recfile)
    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
    check_batches(loader)

    # with transform
    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
    check_batches(loader)

    # try limit recursion depth
    old_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(500)  # this should be smaller than any default value used in python
    try:
        dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
        check_batches(loader)
    finally:
        # Restore the limit on failure too, so it does not leak into later tests.
        sys.setrecursionlimit(old_limit)
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:30,代码来源:test_gluon_data.py

示例13: test_multi_worker

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def test_multi_worker():
    """Check that a 5-worker DataLoader yields batches whose values equal their position."""
    dataset = Dataset()
    loader = gluon.data.DataLoader(dataset, batch_size=1, num_workers=5)
    for position, batch in enumerate(loader):
        values = batch.asnumpy()
        assert (values == position).all()
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:7,代码来源:test_gluon_data.py

示例14: test_multi_worker_forked_data_loader

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def test_multi_worker_forked_data_loader():
    """Iterate two 2-worker DataLoaders to completion; passes if nothing raises."""
    # First configuration: _Dummy(False) batched with _batchify.
    plain_set = _Dummy(False)
    plain_loader = DataLoader(plain_set, batch_size=40, batchify_fn=_batchify, num_workers=2)
    for _ in plain_loader:
        pass

    # Second configuration: _Dummy(True) batched with _batchify_list.
    list_set = _Dummy(True)
    list_loader = DataLoader(list_set, batch_size=40, batchify_fn=_batchify_list, num_workers=2)
    for _ in list_loader:
        pass
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:14,代码来源:test_gluon_data.py

示例15: __init__

# 需要导入模块: from mxnet.gluon import data [as 别名]
# 或者: from mxnet.gluon.data import DataLoader [as 别名]
def __init__(self,
                 dataset,
                 batch_size,
                 collate_fn=collate,
                 seed=0,
                 shuffle=True,
                 split_name='fold10',
                 fold_idx=0,
                 split_ratio=0.7):
        """Split ``dataset`` into train/validation subsets and build their loaders.

        :param dataset: iterable of 2-item samples; the second item is the label
        :param batch_size: batch size for both loaders
        :param collate_fn: passed to ``DataLoader`` as ``batchify_fn``
        :param seed: stored on self and forwarded to the split helper
        :param shuffle: stored on self and forwarded to the split helper
        :param split_name: ``'fold10'`` or ``'rand'``
        :param fold_idx: forwarded to ``_split_fold10`` (``'fold10'`` only)
        :param split_ratio: forwarded to ``_split_rand`` (``'rand'`` only)
        :raises NotImplementedError: for any other ``split_name``
        """
        self.shuffle = shuffle
        self.seed = seed

        # Labels are collected before the split name is validated.
        labels = [label for _, label in dataset]

        if split_name == 'rand':
            train_idx, valid_idx = self._split_rand(
                labels, split_ratio, seed, shuffle)
        elif split_name == 'fold10':
            train_idx, valid_idx = self._split_fold10(
                labels, fold_idx, seed, shuffle)
        else:
            raise NotImplementedError()

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # Both loaders read the same dataset; only the sampler differs.
        loader_options = dict(batch_size=batch_size, batchify_fn=collate_fn)
        self.train_loader = DataLoader(dataset, sampler=train_sampler, **loader_options)
        self.valid_loader = DataLoader(dataset, sampler=valid_sampler, **loader_options)
开发者ID:dmlc,项目名称:dgl,代码行数:35,代码来源:dataloader.py


注：本文中的mxnet.gluon.data.DataLoader类示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。