本文整理汇总了 Python 中 tensorpack.dataflow.MultiThreadMapData 类的典型用法代码示例。如果您正苦于以下问题：Python dataflow.MultiThreadMapData 的具体用法？Python dataflow.MultiThreadMapData 怎么用？Python dataflow.MultiThreadMapData 使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 tensorpack.dataflow 的用法示例。
在下文中一共展示了 dataflow.MultiThreadMapData 类的 4 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: get_dataflow
# 需要导入模块: from tensorpack import dataflow [as 别名]
# 或者: from tensorpack.dataflow import MultiThreadMapData [as 别名]
def get_dataflow(path, is_train, img_path=None):
    """Build the pose-estimation dataflow for training or evaluation.

    The source is ``CocoPose(path, img_path, is_train)``. In training mode
    the images are loaded with ``read_image_url`` and passed through the
    ``pose_random_scale`` / ``pose_rotation`` / ``pose_flip`` /
    ``pose_resize_shortestedge_random`` / ``pose_crop_random`` component
    mappers. In evaluation mode images are loaded by a 16-thread
    ``MultiThreadMapData`` and passed through
    ``pose_resize_shortestedge_fixed`` and ``pose_crop_center``. Both
    branches finish with ``pose_to_img`` and a ``PrefetchData`` wrapper.

    Args:
        path: First argument forwarded to ``CocoPose``.
        is_train: Selects the training (truthy) or evaluation pipeline.
        img_path: Optional second argument forwarded to ``CocoPose``.

    Returns:
        The composed dataflow object.
    """
    ds = CocoPose(path, img_path, is_train)  # read data from lmdb
    if is_train:
        ds = MapData(ds, read_image_url)
        for component_fn in (
            pose_random_scale,
            pose_rotation,
            pose_flip,
            pose_resize_shortestedge_random,
            pose_crop_random,
        ):
            ds = MapDataComponent(ds, component_fn)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count() * 4)
    else:
        ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
        for component_fn in (pose_resize_shortestedge_fixed, pose_crop_center):
            ds = MapDataComponent(ds, component_fn)
        ds = MapData(ds, pose_to_img)
        # cpu_count() // 4 is 0 on machines with fewer than four cores, which
        # would request zero prefetch workers; always keep at least one.
        ds = PrefetchData(ds, 100, max(1, multiprocessing.cpu_count() // 4))
    return ds
示例2: get_dataflow
# 需要导入模块: from tensorpack import dataflow [as 别名]
# 或者: from tensorpack.dataflow import MultiThreadMapData [as 别名]
def get_dataflow(path, is_train, img_path=None):
    """Build the pose-estimation dataflow for training or evaluation.

    The source is ``CocoPose(path, img_path, is_train)``. In training mode
    the images are loaded with ``read_image_url`` and passed through the
    ``pose_random_scale`` / ``pose_rotation`` / ``pose_flip`` /
    ``pose_resize_shortestedge_random`` / ``pose_crop_random`` component
    mappers, then prefetched with one worker per reported CPU. In
    evaluation mode images are loaded by a 16-thread ``MultiThreadMapData``
    and passed through ``pose_resize_shortestedge_fixed`` and
    ``pose_crop_center``. Both branches apply ``pose_to_img`` before the
    final ``PrefetchData`` wrapper.

    Args:
        path: First argument forwarded to ``CocoPose``.
        is_train: Selects the training (truthy) or evaluation pipeline.
        img_path: Optional second argument forwarded to ``CocoPose``.

    Returns:
        The composed dataflow object.
    """
    ds = CocoPose(path, img_path, is_train)  # read data from lmdb
    if is_train:
        ds = MapData(ds, read_image_url)
        for component_fn in (
            pose_random_scale,
            pose_rotation,
            pose_flip,
            pose_resize_shortestedge_random,
            pose_crop_random,
        ):
            ds = MapDataComponent(ds, component_fn)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        ds = PrefetchData(ds, 1000, multiprocessing.cpu_count())
    else:
        ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
        for component_fn in (pose_resize_shortestedge_fixed, pose_crop_center):
            ds = MapDataComponent(ds, component_fn)
        ds = MapData(ds, pose_to_img)
        # cpu_count() // 4 is 0 on machines with fewer than four cores, which
        # would request zero prefetch workers; always keep at least one.
        ds = PrefetchData(ds, 100, max(1, multiprocessing.cpu_count() // 4))
    return ds
示例3: get_dataflow
# 需要导入模块: from tensorpack import dataflow [as 别名]
# 或者: from tensorpack.dataflow import MultiThreadMapData [as 别名]
def get_dataflow(path, is_train, img_path=None):
    """Build the pose-estimation dataflow for training or evaluation.

    The source is ``CocoPose(path, img_path, is_train)``. In training mode
    the images are loaded with ``read_image_url`` and passed through the
    ``pose_random_scale`` / ``pose_rotation`` / ``pose_flip`` /
    ``pose_resize_shortestedge_random`` / ``pose_crop_random`` component
    mappers, then prefetched with one worker fewer than the reported CPU
    count. In evaluation mode images are loaded by a 16-thread
    ``MultiThreadMapData`` and passed through
    ``pose_resize_shortestedge_fixed`` and ``pose_crop_center``. Both
    branches apply ``pose_to_img`` before the final ``PrefetchData``
    wrapper.

    Args:
        path: First argument forwarded to ``CocoPose``.
        is_train: Selects the training (truthy) or evaluation pipeline.
        img_path: Optional second argument forwarded to ``CocoPose``.

    Returns:
        The composed dataflow object.
    """
    cpu_total = multiprocessing.cpu_count()
    ds = CocoPose(path, img_path, is_train)  # read data from lmdb
    if is_train:
        ds = MapData(ds, read_image_url)
        for component_fn in (
            pose_random_scale,
            pose_rotation,
            pose_flip,
            pose_resize_shortestedge_random,
            pose_crop_random,
        ):
            ds = MapDataComponent(ds, component_fn)
        ds = MapData(ds, pose_to_img)
        # augs = [
        #     imgaug.RandomApplyAug(imgaug.RandomChooseAug([
        #         imgaug.GaussianBlur(max_size=3)
        #     ]), 0.7)
        # ]
        # ds = AugmentImageComponent(ds, augs)
        # cpu_count() - 1 is 0 on a single-core machine; keep at least one
        # prefetch worker so the pipeline can still produce data.
        ds = PrefetchData(ds, 1000, max(1, cpu_total - 1))
    else:
        ds = MultiThreadMapData(ds, nr_thread=16, map_func=read_image_url, buffer_size=1000)
        for component_fn in (pose_resize_shortestedge_fixed, pose_crop_center):
            ds = MapDataComponent(ds, component_fn)
        ds = MapData(ds, pose_to_img)
        # cpu_count() // 4 is 0 on machines with fewer than four cores, which
        # would request zero prefetch workers; always keep at least one.
        ds = PrefetchData(ds, 100, max(1, cpu_total // 4))
    return ds
示例4: get_imagenet_dataflow
# 需要导入模块: from tensorpack import dataflow [as 别名]
# 或者: from tensorpack.dataflow import MultiThreadMapData [as 别名]
def get_imagenet_dataflow(datadir,
                          is_train,
                          batch_size,
                          augmentors,
                          parallel=None):
    """
    Build a batched ILSVRC12 dataflow for training or validation.

    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html

    Args:
        datadir: Dataset directory forwarded to ``dataset.ILSVRC12`` /
            ``dataset.ILSVRC12Files``; must not be None.
        is_train: When truthy, read the shuffled "train" split, augment it
            and prefetch with ``parallel`` processes; otherwise read the
            unshuffled "val" file list and decode/augment it with
            ``parallel`` threads.
        batch_size: Batch size passed to ``BatchData``. The training branch
            drops the incomplete final batch (``remainder=False``); the
            validation branch keeps it (``remainder=True``).
        augmentors: List of augmentors applied to each image.
        parallel: Number of worker processes/threads. Defaults to
            ``min(40, cpu_count() // 2)``, but never less than 1.

    Returns:
        The composed dataflow object.

    Raises:
        AssertionError: If ``datadir`` is None or ``augmentors`` is not a list.
        IOError: Raised from the validation mapper when an image file
            cannot be read.
    """
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        # assuming hyperthreading; cpu_count() // 2 is 0 on a single-core
        # machine, so clamp to keep at least one worker.
        parallel = max(1, min(40, multiprocessing.cpu_count() // 2))
    if is_train:
        ds = dataset.ILSVRC12(datadir, "train", shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logging.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, "val", shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            """Decode one (filename, label) datapoint into (image, label)."""
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            if im is None:
                # cv2.imread reports an unreadable or missing file by
                # returning None; fail here with the filename instead of
                # letting np.flip raise an unrelated axis error.
                raise IOError("Failed to read image: {}".format(fname))
            # Reverse the channel axis (OpenCV loads colour images as BGR).
            im = np.flip(im, axis=2)
            im = aug.augment(im)
            return im, cls

        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        # ds = MapData(ds, mapf)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
        # ds = PrefetchData(ds, 1)
    return ds