当前位置: 首页>>代码示例>>Python>>正文


Python common.BatchData方法代码示例

本文整理汇总了Python中tensorpack.dataflow.common.BatchData方法的典型用法代码示例。如果您正苦于以下问题:Python common.BatchData方法的具体用法?Python common.BatchData怎么用?Python common.BatchData使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorpack.dataflow.common的用法示例。


在下文中一共展示了common.BatchData方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_mnist_data

# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_mnist_data(is_train, image_size, batchsize):
    ds = MNISTCh('train' if is_train else 'test', shuffle=True)

    if is_train:
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomResize((0.8, 1.2), (0.8, 1.2)), 0.3),
            imgaug.RandomApplyAug(imgaug.RotationAndCropValid(15), 0.5),
            imgaug.RandomApplyAug(imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01), 0.25),
            imgaug.Resize((224, 224), cv2.INTER_AREA)
        ]
        ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 128*10, multiprocessing.cpu_count())
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 256, 4)
    else:
        # no augmentation, only resizing
        augs = [
            imgaug.Resize((image_size, image_size), cv2.INTER_CUBIC),
        ]
        ds = AugmentImageComponent(ds, augs)
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 20, 2)
    return ds 
开发者ID:ildoonet,项目名称:tf-lcnn,代码行数:25,代码来源:data_feeder.py

示例2: get_ilsvrc_data_alexnet

# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_ilsvrc_data_alexnet(is_train, image_size, batchsize, directory):
    if is_train:
        if not directory.startswith('/'):
            ds = ILSVRCTTenthTrain(directory)
        else:
            ds = ILSVRC12(directory, 'train')
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomResize((0.9, 1.2), (0.9, 1.2)), 0.7),
            imgaug.RandomApplyAug(imgaug.RotationAndCropValid(15), 0.7),
            imgaug.RandomApplyAug(imgaug.RandomChooseAug([
                imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
                imgaug.RandomOrderAug([
                    imgaug.BrightnessScale((0.8, 1.2), clip=False),
                    imgaug.Contrast((0.8, 1.2), clip=False),
                    # imgaug.Saturation(0.4, rgb=True),
                ]),
            ]), 0.7),
            imgaug.Flip(horiz=True),

            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.RandomCrop((224, 224)),
        ]
        ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 10, 4)
    else:
        if not directory.startswith('/'):
            ds = ILSVRCTenthValid(directory)
        else:
            ds = ILSVRC12(directory, 'val')
        ds = AugmentImageComponent(ds, [
            imgaug.ResizeShortestEdge(224, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ])
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count())
        ds = BatchData(ds, batchsize)

    return ds 
开发者ID:ildoonet,项目名称:tf-lcnn,代码行数:41,代码来源:data_feeder.py

示例3: get_dataflow_batch

# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    logger.info('dataflow img_path=%s' % img_path)
    ds = get_dataflow(path, is_train, img_path=img_path)
    ds = BatchData(ds, batchsize)
    if is_train:
        ds = PrefetchData(ds, 10, 2)
    else:
        ds = PrefetchData(ds, 50, 2)

    return ds 
开发者ID:SrikanthVelpuri,项目名称:tf-pose,代码行数:12,代码来源:pose_dataset.py

示例4: dataflow

# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def dataflow(self, nr_prefetch=1000, nr_thread=1):
        ds = self
        ds = BatchData(ds, self.batch_size)
        ds = PrefetchData(ds, nr_prefetch, nr_thread)
        return ds 
开发者ID:andabi,项目名称:voice-vector,代码行数:7,代码来源:data_load.py

示例5: get_remote_dataflow

# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_remote_dataflow(port, nr_prefetch=1000, nr_thread=1):
    ipc = 'ipc:///tmp/ipc-socket'
    tcp = 'tcp://0.0.0.0:%d' % port
    data_loader = RemoteDataZMQ(ipc, tcp, hwm=10000)
    data_loader = BatchData(data_loader, batch_size=hp.train.batch_size)
    data_loader = PrefetchData(data_loader, nr_prefetch, nr_thread)
    return data_loader 
开发者ID:andabi,项目名称:voice-vector,代码行数:9,代码来源:train.py

示例6: get_dataflow_batch

# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    logger.info('dataflow img_path=%s' % img_path)
    ds = get_dataflow(path, is_train, img_path=img_path)
    ds = BatchData(ds, batchsize)
    # if is_train:
    #     ds = PrefetchData(ds, 10, 2)
    # else:
    #     ds = PrefetchData(ds, 50, 2)

    return ds 
开发者ID:PINTO0309,项目名称:MobileNetV2-PoseEstimation,代码行数:12,代码来源:pose_dataset.py

示例7: __call__

# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def __call__(self, n_prefetch=1000, n_thread=1):
        df = self
        df = BatchData(df, self.batch_size)
        df = PrefetchData(df, n_prefetch, n_thread)
        return df 
开发者ID:andabi,项目名称:deep-voice-conversion,代码行数:7,代码来源:data_load.py


注:本文中的tensorpack.dataflow.common.BatchData方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。