This article collects typical usage examples of the Python method tensorpack.dataflow.BatchData: what the method does, how it is called, and where it fits. If you are unsure how to use dataflow.BatchData, the curated code examples here should help; you can also explore the wider tensorpack.dataflow module for related utilities.
The following presents 10 code examples of dataflow.BatchData, ordered by popularity by default.
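Before the examples, a minimal sketch of what BatchData does (this sketch is ours, not one of the examples below, and assumes only a working tensorpack install): it wraps any DataFlow and groups every batch_size consecutive datapoints into a single datapoint of stacked numpy arrays.

from tensorpack.dataflow import DataFromList, BatchData

ds = DataFromList([[i, i * i] for i in range(10)], shuffle=False)
ds = BatchData(ds, 4, remainder=True)  # remainder=True keeps the final short batch
ds.reset_state()
for batch in ds.get_data():
    print(batch)  # each component stacked: shapes (4,), (4,), then (2,) for the leftover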
Example 1: __init__
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def __init__(self, mode, batch_size=256, shuffle=False, num_workers=25, cache=50000,
             collate_fn=default_collate, drop_last=False, cuda=False):
    # enumerate standard imagenet augmentors
    imagenet_augmentors = fbresnet_augmentor(mode == 'train')
    # load the lmdb if we can find it
    lmdb_loc = os.path.join(os.environ['IMAGENET'], 'ILSVRC-%s.lmdb' % mode)
    ds = td.LMDBData(lmdb_loc, shuffle=False)
    ds = td.LocallyShuffleData(ds, cache)
    ds = td.PrefetchData(ds, 5000, 1)
    ds = td.LMDBDataPoint(ds)
    ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
    ds = td.AugmentImageComponent(ds, imagenet_augmentors)
    ds = td.PrefetchDataZMQ(ds, num_workers)
    self.ds = td.BatchData(ds, batch_size)
    self.ds.reset_state()
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.cuda = cuda
    # self.drop_last = drop_last
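A hedged usage sketch for the constructor above: the enclosing class name is not shown, so ImagenetLoader below is a hypothetical stand-in; batches are drawn from the wrapped tensorpack dataflow stored in self.ds.

loader = ImagenetLoader('train', batch_size=256)  # hypothetical class name
for images, labels in loader.ds.get_data():
    ...  # images: one (256, H, W, 3) batch assembled by td.BatchData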
Example 2: get_imagenet_dataflow
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def get_imagenet_dataflow(datadir,
                          is_train,
                          batch_size,
                          augmentors,
                          parallel=None):
    """
    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
    """
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 2)  # assuming hyperthreading
    if is_train:
        ds = dataset.ILSVRC12(datadir, "train", shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logging.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, "val", shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = np.flip(im, axis=2)  # BGR -> RGB
            # print("fname={}".format(fname))
            im = aug.augment(im)
            return im, cls

        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        # ds = MapData(ds, mapf)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
        # ds = PrefetchData(ds, 1)
    return ds
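A hedged call sketch, assuming the fbresnet_augmentor helper from tensorpack's ImageNet examples and a standard ILSVRC12 directory layout; TestDataSpeed is tensorpack's built-in dataflow benchmark.

from tensorpack.dataflow import TestDataSpeed

augs = fbresnet_augmentor(True)  # assumed helper returning a list of augmentors
ds = get_imagenet_dataflow('/path/to/ILSVRC12', True, 64, augs)
TestDataSpeed(ds, size=1000).start()  # measure batches/sec before training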
Example 3: build_iter
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def build_iter(self):
    ds = DataFromGenerator(self.generator)
    ds = BatchData(ds, self.batch_size)
    ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
    ds.reset_state()
    ds = ds.get_data()
    return ds
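Note that build_iter returns the generator produced by get_data(), not the dataflow object, so the caller pulls batches with next(). A sketch, with holder standing in for an instance of the (unshown) enclosing class:

it = holder.build_iter()  # holder: hypothetical instance exposing generator, batch_size, etc.
batch = next(it)          # each next() yields one batch of batch_size datapoints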
Example 4: preprocess_data_flow
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def preprocess_data_flow(ds, options, is_train, do_multiprocess=False):
    ds_size = ds.size()
    while options.batch_size > ds_size:
        options.batch_size //= 2
    ds = BatchData(ds, max(1, options.batch_size // options.nr_gpu),
                   remainder=not is_train)
    if do_multiprocess:
        ds = PrefetchData(ds, 5, 5)
    return ds
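A worked note on the halving loop: with options.batch_size = 256 and a dataset of only 100 datapoints, the loop shrinks the batch size to 128, then 64, before batching; the per-GPU batch is then batch_size // nr_gpu, floored at 1, and remainder=not is_train drops the incomplete last batch during training only. A hedged call sketch, with options assumed to be an argparse-style namespace:

options.batch_size, options.nr_gpu = 64, 2          # assumed argparse-style namespace
train_ds = preprocess_data_flow(ds, options, True)  # per-GPU batches of 32, remainder dropped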
Example 5: critic_dataflow_factory
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def critic_dataflow_factory(ctrl, data, is_train):
    """
    Generate a critic dataflow.
    """
    if ctrl.critic_type == CriticTypes.CONV:
        ds = ConvCriticDataFlow(data, shuffle=is_train, max_depth=ctrl.controller_max_depth)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=False)
    elif ctrl.critic_type == CriticTypes.LSTM:
        ds = LSTMCriticDataFlow(data, shuffle=is_train)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=True)
    return ds
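The use_list flag is the notable difference between the two branches: use_list=False stacks each batch component into one numpy array (uniform shapes required), while use_list=True keeps each component as a plain Python list, which tolerates the variable-length sequences an LSTM critic consumes. A minimal sketch of the distinction:

from tensorpack.dataflow import DataFromList, BatchData

ds = DataFromList([[[1, 2]], [[3, 4, 5]]], shuffle=False)  # ragged single-component datapoints
ds = BatchData(ds, 2, use_list=True)  # keeps lists; use_list=False would try to stack them into an array
ds.reset_state()
print(next(ds.get_data()))  # [[[1, 2], [3, 4, 5]]]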
Example 6: get_data
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def get_data(split, option):
    is_training = split == 'train'
    parallel = multiprocessing.cpu_count() // 2
    ds = get_data_flow(split, is_training, option)
    augmentors = fbresnet_augmentor(is_training, option)
    ds = AugmentImageCoordinates(ds, augmentors, coords_index=2, copy=False)
    if is_training:
        ds = PrefetchDataZMQ(ds, parallel)
    ds = BatchData(ds, option.batch_size, remainder=not is_training)
    return ds
Example 7: build_iter
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def build_iter(self):
    ds = DataFromGenerator(self.generator)
    ds = BatchData(ds, self.num_gpu * self.batch_size)  # one global batch covering all GPUs
    ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
    ds.reset_state()
    ds = ds.get_data()
    return ds
Example 8: get_data
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def get_data(lmdb_path, txt_path):
    if txt_path:
        ds = arod_dataflow_from_txt.Triplets(lmdb_path, txt_path, IMAGE_HEIGHT, IMAGE_WIDTH)
    else:
        ds = arod_provider.Triplets(lmdb_path, IMAGE_HEIGHT, IMAGE_WIDTH)
    ds.reset_state()
    cpu = min(10, multiprocessing.cpu_count())
    ds = PrefetchDataZMQ(ds, cpu)
    ds = BatchData(ds, BATCH_SIZE)
    return ds
Example 9: get_data
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def get_data():
    def f(dp):
        im = dp[0][:, :, None]      # add a channel axis: (28, 28) -> (28, 28, 1)
        onehot = np.eye(10)[dp[1]]  # integer label -> one-hot vector of length 10
        return [im, onehot]

    train = BatchData(MapData(dataset.Mnist('train'), f), 128)
    test = BatchData(MapData(dataset.Mnist('test'), f), 256)
    return train, test
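A quick consumption sketch for the MNIST flows above; the component shapes follow from the mapper f:

train, test = get_data()
train.reset_state()
for im, onehot in train.get_data():
    print(im.shape, onehot.shape)  # (128, 28, 28, 1) (128, 10)
    break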
Example 10: get_test_data
# Required import: from tensorpack import dataflow [as alias]
# Or: from tensorpack.dataflow import BatchData [as alias]
def get_test_data(batch=128):
    ds = dataset.Mnist('test')
    ds = BatchData(ds, batch)
    return ds