本文整理汇总了Python中tensorpack.dataflow.common.BatchData的典型用法代码示例。如果您正苦于以下问题：Python common.BatchData的具体用法是什么？common.BatchData怎么用？有哪些使用common.BatchData的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解BatchData所在模块tensorpack.dataflow.common的用法示例。
在下文中一共展示了common.BatchData的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_mnist_data
# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_mnist_data(is_train, image_size, batchsize):
    """Build a batched, prefetched MNIST dataflow.

    Args:
        is_train: if True, read the 'train' split and apply random
            augmentation; otherwise read the 'test' split and only resize.
        image_size: side length in pixels that every image is resized to,
            in both the training and the test pipeline.
        batchsize: number of datapoints grouped into each batch.

    Returns:
        A tensorpack DataFlow yielding batches of ``batchsize`` datapoints.
    """
    # MNISTCh is defined elsewhere in the project; presumably an MNIST
    # variant adapted for this model's input channels -- TODO confirm.
    ds = MNISTCh('train' if is_train else 'test', shuffle=True)
    if is_train:
        augs = [
            imgaug.RandomApplyAug(imgaug.RandomResize((0.8, 1.2), (0.8, 1.2)), 0.3),
            imgaug.RandomApplyAug(imgaug.RotationAndCropValid(15), 0.5),
            imgaug.RandomApplyAug(
                imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01), 0.25),
            # Resize to image_size (this used to be hard-coded to 224) so
            # training images have the same resolution as the test images.
            imgaug.Resize((image_size, image_size), cv2.INTER_AREA),
        ]
        ds = AugmentImageComponent(ds, augs)
        # Prefetch augmented samples with one worker per CPU core; assumes
        # the PrefetchData(ds, nr_prefetch, nr_proc) signature.
        ds = PrefetchData(ds, 128 * 10, multiprocessing.cpu_count())
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 256, 4)
    else:
        # No augmentation, only resizing.
        augs = [
            imgaug.Resize((image_size, image_size), cv2.INTER_CUBIC),
        ]
        ds = AugmentImageComponent(ds, augs)
        ds = BatchData(ds, batchsize)
        ds = PrefetchData(ds, 20, 2)
    return ds
示例2: get_ilsvrc_data_alexnet
# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_ilsvrc_data_alexnet(is_train, image_size, batchsize, directory):
    """Build a batched ILSVRC dataflow with AlexNet-style preprocessing.

    Which reader is used depends on ``directory``: a path starting with '/'
    goes to ``ILSVRC12``; any other value goes to the project's "tenth"
    readers (defined elsewhere).

    Args:
        is_train: True for the augmented training pipeline, False for the
            validation pipeline.
        image_size: accepted but not used by this function; crops are fixed
            at 224x224. NOTE(review): confirm this is intended.
        batchsize: number of datapoints per batch.
        directory: dataset location, interpreted as described above.

    Returns:
        A tensorpack DataFlow yielding batches of ``batchsize`` datapoints.
    """
    is_absolute = directory.startswith('/')

    if not is_train:
        if is_absolute:
            ds = ILSVRC12(directory, 'val')
        else:
            ds = ILSVRCTenthValid(directory)
        # NOTE(review): training resizes the shortest edge to 256 before a
        # 224 crop, while validation resizes it to 224 -- confirm intended.
        ds = AugmentImageComponent(ds, [
            imgaug.ResizeShortestEdge(224, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224)),
        ])
        ds = PrefetchData(ds, 100, multiprocessing.cpu_count())
        return BatchData(ds, batchsize)

    if is_absolute:
        ds = ILSVRC12(directory, 'train')
    else:
        # NOTE(review): this name has an extra 'T' compared with
        # ILSVRCTenthValid above; confirm it matches the class definition.
        ds = ILSVRCTTenthTrain(directory)

    # Random geometric jitter: rescale and small rotations.
    geometric = [
        imgaug.RandomApplyAug(imgaug.RandomResize((0.9, 1.2), (0.9, 1.2)), 0.7),
        imgaug.RandomApplyAug(imgaug.RotationAndCropValid(15), 0.7),
    ]
    # With probability 0.7, apply one of: salt-and-pepper noise, or
    # brightness + contrast scaling in random order.
    pixel_level = imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
        imgaug.RandomOrderAug([
            imgaug.BrightnessScale((0.8, 1.2), clip=False),
            imgaug.Contrast((0.8, 1.2), clip=False),
        ]),
    ]), 0.7)
    # Horizontal flip, then resize and take a random 224x224 crop.
    finishing = [
        imgaug.Flip(horiz=True),
        imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
        imgaug.RandomCrop((224, 224)),
    ]

    ds = AugmentImageComponent(ds, geometric + [pixel_level] + finishing)
    ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())
    ds = BatchData(ds, batchsize)
    return PrefetchData(ds, 10, 4)
示例3: get_dataflow_batch
# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    """Wrap ``get_dataflow`` with batching and background prefetching.

    Args:
        path: forwarded unchanged to ``get_dataflow``.
        is_train: forwarded to ``get_dataflow``; also selects the prefetch
            queue size (10 for training, 50 otherwise).
        batchsize: number of datapoints per batch.
        img_path: optional value forwarded to ``get_dataflow`` and logged.

    Returns:
        A PrefetchData-wrapped BatchData dataflow.
    """
    logger.info('dataflow img_path=%s' % img_path)
    source = get_dataflow(path, is_train, img_path=img_path)
    batched = BatchData(source, batchsize)
    queue_size = 10 if is_train else 50
    return PrefetchData(batched, queue_size, 2)
示例4: dataflow
# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def dataflow(self, nr_prefetch=1000, nr_thread=1):
    """Return this dataflow grouped into batches of ``self.batch_size`` and
    wrapped in PrefetchData.

    Args:
        nr_prefetch: prefetch buffer size passed to PrefetchData.
        nr_thread: worker count passed to PrefetchData.
    """
    batched = BatchData(self, self.batch_size)
    return PrefetchData(batched, nr_prefetch, nr_thread)
示例5: get_remote_dataflow
# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_remote_dataflow(port, nr_prefetch=1000, nr_thread=1):
    """Build a batched, prefetched dataflow fed by a ZMQ receiver.

    Uses a fixed IPC socket ('ipc:///tmp/ipc-socket') plus a TCP endpoint
    on 0.0.0.0 at ``port``.

    Args:
        port: TCP port for the second ZMQ endpoint.
        nr_prefetch: prefetch buffer size passed to PrefetchData.
        nr_thread: worker count passed to PrefetchData.

    Returns:
        A PrefetchData-wrapped BatchData dataflow.
    """
    # NOTE(review): 0.0.0.0 exposes the endpoint on every interface, and
    # tensorpack deserializes whatever it receives. If the configured
    # serializer is pickle, untrusted peers could execute code; consider
    # binding to 127.0.0.1 instead.
    ipc_addr = 'ipc:///tmp/ipc-socket'
    tcp_addr = 'tcp://0.0.0.0:%d' % port
    receiver = RemoteDataZMQ(ipc_addr, tcp_addr, hwm=10000)
    # hp is a module-level config defined elsewhere; presumably the
    # project's hyperparameters -- confirm.
    batched = BatchData(receiver, batch_size=hp.train.batch_size)
    return PrefetchData(batched, nr_prefetch, nr_thread)
示例6: get_dataflow_batch
# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def get_dataflow_batch(path, is_train, batchsize, img_path=None):
    """Wrap ``get_dataflow`` with batching only.

    No prefetching is applied here; wrap the result in PrefetchData if
    background loading is needed.

    Args:
        path: forwarded unchanged to ``get_dataflow``.
        is_train: forwarded unchanged to ``get_dataflow``.
        batchsize: number of datapoints per batch.
        img_path: optional value forwarded to ``get_dataflow`` and logged.

    Returns:
        A BatchData dataflow.
    """
    logger.info('dataflow img_path=%s' % img_path)
    return BatchData(get_dataflow(path, is_train, img_path=img_path), batchsize)
示例7: __call__
# 需要导入模块: from tensorpack.dataflow import common [as 别名]
# 或者: from tensorpack.dataflow.common import BatchData [as 别名]
def __call__(self, n_prefetch=1000, n_thread=1):
    """Calling the object returns it grouped into batches of
    ``self.batch_size`` and wrapped in PrefetchData.

    Args:
        n_prefetch: prefetch buffer size passed to PrefetchData.
        n_thread: worker count passed to PrefetchData.
    """
    return PrefetchData(BatchData(self, self.batch_size), n_prefetch, n_thread)