當前位置: 首頁>>代碼示例>>Python>>正文


Python dataflow.BatchData方法代碼示例

本文整理匯總了Python中tensorpack.dataflow.BatchData方法的典型用法代碼示例。如果您正苦於以下問題:Python dataflow.BatchData方法的具體用法?Python dataflow.BatchData怎麽用?Python dataflow.BatchData使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorpack.dataflow的用法示例。


在下文中一共展示了dataflow.BatchData方法的10個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def __init__(self, mode, batch_size=256, shuffle=False, num_workers=25, cache=50000,
            collate_fn=default_collate,  drop_last=False, cuda=False):
        # enumerate standard imagenet augmentors
        imagenet_augmentors = fbresnet_augmentor(mode == 'train')

        # load the lmdb if we can find it
        lmdb_loc = os.path.join(os.environ['IMAGENET'],'ILSVRC-%s.lmdb'%mode)
        ds = td.LMDBData(lmdb_loc, shuffle=False)
        ds = td.LocallyShuffleData(ds, cache)
        ds = td.PrefetchData(ds, 5000, 1)
        ds = td.LMDBDataPoint(ds)
        ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
        ds = td.AugmentImageComponent(ds, imagenet_augmentors)
        ds = td.PrefetchDataZMQ(ds, num_workers)
        self.ds = td.BatchData(ds, batch_size)
        self.ds.reset_state()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.cuda = cuda
        #self.drop_last = drop_last 
開發者ID:BayesWatch,項目名稱:sequential-imagenet-dataloader,代碼行數:23,代碼來源:data.py

示例2: get_imagenet_dataflow

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def get_imagenet_dataflow(datadir,
                          is_train,
                          batch_size,
                          augmentors,
                          parallel=None):
    """
    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
    """
    assert datadir is not None
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 2)  # assuming hyperthreading
    if is_train:
        ds = dataset.ILSVRC12(datadir, "train", shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logging.warning("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, "val", shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = np.flip(im, axis=2)
            # print("fname={}".format(fname))
            im = aug.augment(im)
            return im, cls
        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        # ds = MapData(ds, mapf)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
        # ds = PrefetchData(ds, 1)
    return ds 
開發者ID:osmr,項目名稱:imgclsmob,代碼行數:39,代碼來源:utils_tp.py

示例3: build_iter

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def build_iter(self):

        ds = DataFromGenerator(self.generator)
        ds = BatchData(ds, self.batch_size)
        ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
        ds.reset_state()
        ds = ds.get_data()
        return ds 
開發者ID:610265158,項目名稱:face_landmark,代碼行數:10,代碼來源:dataietr.py

示例4: preprocess_data_flow

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def preprocess_data_flow(ds, options, is_train, do_multiprocess=False):
    ds_size = ds.size()
    while options.batch_size > ds_size:
        options.batch_size //= 2
    ds = BatchData(ds, max(1, options.batch_size // options.nr_gpu),
        remainder=not is_train)
    if do_multiprocess:
        ds = PrefetchData(ds, 5, 5)
    return ds 
開發者ID:microsoft,項目名稱:petridishnn,代碼行數:11,代碼來源:misc.py

示例5: critic_dataflow_factory

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def critic_dataflow_factory(ctrl, data, is_train):
    """
    Generate a critic dataflow
    """
    if ctrl.critic_type == CriticTypes.CONV:
        ds = ConvCriticDataFlow(data, shuffle=is_train, max_depth=ctrl.controller_max_depth)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=False)
    elif ctrl.critic_type == CriticTypes.LSTM:
        ds = LSTMCriticDataFlow(data, shuffle=is_train)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=True)
    return ds 
開發者ID:microsoft,項目名稱:petridishnn,代碼行數:13,代碼來源:critic.py

示例6: get_data

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def get_data(split, option):
    is_training = split == 'train'
    parallel = multiprocessing.cpu_count() // 2
    ds = get_data_flow(split, is_training, option)
    augmentors = fbresnet_augmentor(is_training, option)
    ds = AugmentImageCoordinates(ds, augmentors, coords_index=2, copy=False)
    if is_training:
        ds = PrefetchDataZMQ(ds, parallel)
    ds = BatchData(ds, option.batch_size, remainder=not is_training)
    return ds 
開發者ID:junsukchoe,項目名稱:ADL,代碼行數:12,代碼來源:data_loader.py

示例7: build_iter

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def build_iter(self):


        ds = DataFromGenerator(self.generator)

        ds = BatchData(ds, self.num_gpu *  self.batch_size)

        ds = MultiProcessPrefetchData(ds, self.prefetch_size, self.process_num)
        ds.reset_state()
        ds = ds.get_data()
        return ds 
開發者ID:610265158,項目名稱:faceboxes-tensorflow,代碼行數:13,代碼來源:dataietr.py

示例8: get_data

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def get_data(lmdb_path, txt_path):

        if txt_path:
            ds = arod_dataflow_from_txt.Triplets(lmdb_path, txt_path, IMAGE_HEIGHT, IMAGE_WIDTH)
        else:
            ds = arod_provider.Triplets(lmdb_path, IMAGE_HEIGHT, IMAGE_WIDTH)

        ds.reset_state()
        cpu = min(10, multiprocessing.cpu_count())
        ds = PrefetchDataZMQ(ds, cpu)
        ds = BatchData(ds, BATCH_SIZE)
        return ds 
開發者ID:cgtuebingen,項目名稱:will-people-like-your-image,代碼行數:14,代碼來源:resnet50_for_embedding.py

示例9: get_data

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def get_data():
    def f(dp):
        im = dp[0][:, :, None]
        onehot = np.eye(10)[dp[1]]
        return [im, onehot]

    train = BatchData(MapData(dataset.Mnist('train'), f), 128)
    test = BatchData(MapData(dataset.Mnist('test'), f), 256)
    return train, test 
開發者ID:tensorpack,項目名稱:tensorpack,代碼行數:11,代碼來源:mnist-keras-v2.py

示例10: get_test_data

# 需要導入模塊: from tensorpack import dataflow [as 別名]
# 或者: from tensorpack.dataflow import BatchData [as 別名]
def get_test_data(batch=128):
    ds = dataset.Mnist('test')
    ds = BatchData(ds, batch)
    return ds 
開發者ID:tensorpack,項目名稱:tensorpack,代碼行數:6,代碼來源:embedding_data.py


注:本文中的tensorpack.dataflow.BatchData方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。