当前位置: 首页>>代码示例>>Python>>正文


Python config.BATCH_SIZE属性代码示例

本文整理汇总了 Python 中 config.BATCH_SIZE 属性的典型用法代码示例。如果您想了解 config.BATCH_SIZE 属性的具体用法、使用方式或实际例子,下面精选的代码示例或许能为您提供帮助。您也可以进一步了解该属性所在模块 config 的其他用法示例。


在下文中一共展示了config.BATCH_SIZE属性的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: setup

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def setup(self, bottom, top):
    """Setup the RoIDataLayer used for triplet training.

    Reads the batch size from ``config.BATCH_SIZE``, which must be a
    multiple of 3 because the batch is split into ``batch // 3`` triplets.
    It then creates the sample container and shapes the two output blobs.

    :param bottom: input blobs (not used by this data layer).
    :param top: output blobs; ``top[0]`` is reshaped to
        ``(batch, 3, 224, 224)`` and ``top[1]`` to ``(batch,)``.
    :raises AssertionError: if ``config.BATCH_SIZE`` is not divisible by 3.
    """
    # Parse the layer parameter string, which must be valid YAML.
    # safe_load: a bare yaml.load() without a Loader can construct arbitrary
    # objects and is a TypeError on PyYAML >= 6.
    layer_params = yaml.safe_load(self.param_str_)
    self._batch_size = config.BATCH_SIZE
    # Validate before deriving the triplet count from it.
    assert self._batch_size % 3 == 0
    # Integer division: '/' would give a float on Python 3.
    self._triplet = self._batch_size // 3
    self._name_to_top_map = {
        'data': 0,
        'labels': 1}

    self.data_container = sampledata()
    self._index = 0

    # data blob: holds a batch of N images, each with 3 channels.
    # The height and width (224 x 224) are dummy values.
    top[0].reshape(self._batch_size, 3, 224, 224)

    # label blob: one label per image in the batch.
    top[1].reshape(self._batch_size)
开发者ID:luhaofang,项目名称:tripletloss,代码行数:21,代码来源:datalayer.py

示例2: __init__

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, val_split, rebuild=False, data_aug=False):
    """Set up the TF_flowers dataset reader (5 flower classes).

    :param val_split: stored as ``self.val_split``.
    :param rebuild: stored as ``self.rebuild``.
    :param data_aug: stored as ``self.data_aug``.
    :raises AssertionError: if ``cfg.FLOWERS_PATH`` does not exist.
    """
    self.name = 'TF_flowers'
    # The images live directly under the dataset root, so both paths match.
    self.devkit_path = self.data_path = cfg.FLOWERS_PATH
    self.cache_path = cfg.CACHE_PATH
    self.batch_size, self.image_size = cfg.BATCH_SIZE, cfg.IMAGE_SIZE
    self.rebuild = rebuild
    self.data_aug = data_aug

    # Class names and their integer label ids (position in the list).
    self.num_class = 5
    self.classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    self.class_to_ind = {label: idx for idx, label in enumerate(self.classes)}

    # Iteration state.
    self.train_cursor = self.val_cursor = 0
    self.epoch = 1
    self.gt_labels = None
    self.val_split = val_split

    assert os.path.exists(self.devkit_path), \
        'TF_flowers path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
开发者ID:wenxichen,项目名称:tensorflow_yolo2,代码行数:25,代码来源:TF_flowers.py

示例3: __init__

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, rebuild=False, data_aug=True):
    """Set up the ILSVRC 2017 classification dataset reader.

    :param image_set: stored as ``self.image_set``.
    :param rebuild: stored as ``self.rebuild``.
    :param data_aug: stored as ``self.data_aug`` (defaults to True here).
    :raises AssertionError: if ``cfg.ILSVRC_PATH`` does not exist.
    """
    self.name = 'ilsvrc_2017'
    self.devkit_path = cfg.ILSVRC_PATH
    # The data lives directly under the ILSVRC root.
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = cfg.BATCH_SIZE
    self.image_size = cfg.IMAGE_SIZE
    self.image_set = image_set
    self.rebuild = rebuild
    self.data_aug = data_aug
    self.cursor = 0
    self.load_classes()
    # The message previously said 'VOCdevkit' (copy-paste from the VOC
    # reader) even though this checks the ILSVRC path.
    assert os.path.exists(self.devkit_path), \
        'ILSVRC path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
开发者ID:wenxichen,项目名称:tensorflow_yolo2,代码行数:20,代码来源:ilsvrc2017_cls.py

示例4: __init__

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self):
    """Load the VOC 2007 + 2012 trainval annotations and set up batching.

    Always keep in mind: the small detector corresponds to index 0, the
    medium detector to index 1 and the big detector to index 2.

    Annotations are read from the ``2007_trainval`` and ``2012_trainval``
    entries under ``cfg.DATASET_PATH`` and concatenated.
    """
    self.__dataset_path = cfg.DATASET_PATH
    self.__train_input_sizes = cfg.TRAIN_INPUT_SIZES
    self.__strides = np.array(cfg.STRIDES)
    self.__batch_size = cfg.BATCH_SIZE
    self.__classes = cfg.CLASSES
    self.__num_classes = len(self.__classes)
    self.__gt_per_grid = cfg.GT_PER_GRID
    # Class name -> integer label id (position in cfg.CLASSES).
    self.__class_to_ind = dict(zip(self.__classes, range(self.__num_classes)))

    annotations_2007 = self.__load_annotations(os.path.join(self.__dataset_path, '2007_trainval'))
    annotations_2012 = self.__load_annotations(os.path.join(self.__dataset_path, '2012_trainval'))
    self.__annotations = annotations_2007 + annotations_2012
    self.__num_samples = len(self.__annotations)
    logging.info(('The number of image for train is:').ljust(50) + str(self.__num_samples))
    # Integer ceiling division. np.ceil(n / b) returned a numpy float,
    # which cannot be used where an int is required (e.g. len()/range()).
    # Under Python 2, n / b also floors before the ceil is applied.
    self.__num_batchs = (self.__num_samples + self.__batch_size - 1) // self.__batch_size
    self.__batch_count = 0
开发者ID:PINTO0309,项目名称:PINTO_model_zoo,代码行数:24,代码来源:data.py

示例5: main

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def main():
    """Train LeNet on MNIST and save the parameters to ``cfg.PARAMETER_FILE``.

    If ``cfg.PARAMETER_FILE`` exists, the saved parameters are restored
    first; otherwise all variables are initialised from scratch. The loop
    runs for ``cfg.MAX_ITER`` steps with batches of ``cfg.BATCH_SIZE``.
    Every 100 steps it prints the training accuracy on the current batch.
    """
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    batch_size = cfg.BATCH_SIZE
    parameter_path = cfg.PARAMETER_FILE
    lenet = Lenet()
    max_iter = cfg.MAX_ITER

    # Created after the network is built so it covers all of its variables.
    saver = tf.train.Saver()
    # The context manager closes the session when training finishes.
    with tf.Session() as sess:
        # NOTE(review): TF checkpoints are written as <path>.index /
        # <path>.data-*, so this check may never be true for a plain prefix;
        # tf.train.checkpoint_exists(parameter_path) may be intended - confirm.
        if os.path.exists(parameter_path):
            # Saver.restore takes the session as its first argument.
            saver.restore(sess, parameter_path)
        else:
            # Non-deprecated replacement for tf.initialize_all_variables().
            sess.run(tf.global_variables_initializer())

        for i in range(max_iter):
            # Use the configured batch size rather than a hard-coded 50.
            batch = mnist.train.next_batch(batch_size)
            feed = {lenet.raw_input_image: batch[0],
                    lenet.raw_input_label: batch[1]}
            if i % 100 == 0:
                train_accuracy = sess.run(lenet.train_accuracy, feed_dict=feed)
                print("step %d, training accuracy %g" % (i, train_accuracy))
            sess.run(lenet.train_op, feed_dict=feed)
        saver.save(sess, parameter_path)
开发者ID:ganyc717,项目名称:LeNet,代码行数:26,代码来源:Train.py

示例6: setup

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def setup(self, bottom, top):
    """Setup the DataLayer.

    The batch size is ``cfg.TRIPLET_BATCH_SIZE`` when ``cfg.TRIPLET_LOSS``
    is set and ``cfg.BATCH_SIZE`` otherwise. The method resets the read
    position and epoch counter, then shapes the two output blobs.

    :param bottom: input blobs (not used by this data layer).
    :param top: output blobs; ``top[0]`` becomes ``(batch, 3, 224, 224)``
        and ``top[1]`` becomes ``(batch,)``.
    """
    self.batch_size = (cfg.TRIPLET_BATCH_SIZE if cfg.TRIPLET_LOSS
                       else cfg.BATCH_SIZE)
    # Output blob name -> index into top.
    self._name_to_top_map = dict(data=0, labels=1)
    self._index, self._epoch = 0, 1

    # top[0]: a batch of 3-channel images (the 224 x 224 spatial size is a
    # dummy value); top[1]: one label per image.
    n = self.batch_size
    top[0].reshape(n, 3, 224, 224)
    top[1].reshape(n)
开发者ID:hizhangp,项目名称:triplet,代码行数:21,代码来源:data_layer.py

示例7: setup

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def setup(self, bottom, top):
    """Setup the TripletSelectLayer.

    Shapes ``top[0]``, ``top[1]`` and ``top[2]`` to
    ``(config.BATCH_SIZE // 3, D)``, where ``D`` is the second dimension of
    ``bottom[0].data``.
    """
    # One third of the batch per output. '//' keeps this an int on Python 3;
    # '/' would produce a float blob dimension there.
    self.triplet = config.BATCH_SIZE // 3
    # 'shape' is presumably numpy.shape imported at module level - confirm.
    feature_dim = shape(bottom[0].data)[1]
    for k in range(3):
        top[k].reshape(self.triplet, feature_dim)
开发者ID:luhaofang,项目名称:tripletloss,代码行数:8,代码来源:tripletselectlayer.py

示例8: showProgress

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def showProgress(epoch, done=False):
    """Print a text progress bar for the current training epoch.

    - While ``cfg.STATS['batch_count']`` is 0 or absent, prints the
      ``EPOCH <n> [`` header.
    - After that, prints one ``=`` each time the percentage of completed
      batches reaches a new multiple of 5.
    - When ``done`` is true, prints the closing ``]``.

    :param epoch: epoch number shown in the header.
    :param done: True once the epoch has finished.
    """
    global last_update

    # First call? 'batch_count' is not recorded yet.
    bcnt = cfg.STATS['batch_count'] if 'batch_count' in cfg.STATS else 0

    # Number of batches to train: ceiling division. The old '// B + 1'
    # over-counted when sample_count was an exact multiple of the batch
    # size. Clamped to 1 so the percentage below never divides by zero.
    total_batches = max(1, -(-cfg.STATS['sample_count'] // cfg.BATCH_SIZE))

    if done:
        log.p(']', new_line=False)
        return

    if bcnt == 0:
        log.p(('EPOCH', epoch, '['), new_line=False)
        return

    # Current progress in whole percent. '//' keeps it an int; with true
    # division on Python 3, 'p % 5' is almost never 0 and the bar stalls.
    p = bcnt * 100 // total_batches
    if p % 5 == 0 and p != last_update:
        log.p('=', new_line=False)
        last_update = p

# Clear on first load 
开发者ID:kahst,项目名称:BirdCLEF-Baseline,代码行数:28,代码来源:stats.py

示例9: getDatasetChunk

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def getDatasetChunk(split, batch_size=None):
    """Yield consecutive batch-sized slices of ``split``.

    :param split: sliceable sequence (e.g. a list of image paths).
    :param batch_size: length of each chunk; defaults to ``cfg.BATCH_SIZE``.
    :returns: generator of slices of ``split``; the last slice may be shorter.
    """
    if batch_size is None:
        batch_size = cfg.BATCH_SIZE

    # Get batch-sized chunks of image paths. Uses range instead of the
    # Python-2-only xrange.
    for start in range(0, len(split), batch_size):
        yield split[start:start + batch_size]
开发者ID:kahst,项目名称:BirdCLEF-Baseline,代码行数:7,代码来源:batch_generator.py

示例10: getNextImageBatch

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def getNextImageBatch(split, augmentation=True):
    """Yield ``(x_b, y_b)`` float32 batches built from the samples in ``split``.

    Samples are grouped into chunks by ``getDatasetChunk`` and loaded one by
    one with ``loadImageAndTarget``. A sample whose loading or packing raises
    an exception is skipped, so a yielded batch can hold fewer than
    ``cfg.BATCH_SIZE`` rows, possibly zero.

    :param split: sequence of samples (e.g. image paths).
    :param augmentation: forwarded to ``loadImageAndTarget``.
    :yields: ``x_b`` with shape ``(n, cfg.IM_DIM, cfg.IM_SIZE[1], cfg.IM_SIZE[0])``
        and ``y_b`` with shape ``(n, len(cfg.CLASSES))``, where ``n`` is the
        number of successfully loaded samples in the chunk.
    """
    # fill batch
    for chunk in getDatasetChunk(split):

        # allocate numpy arrays for image data and targets
        x_b = np.zeros((cfg.BATCH_SIZE, cfg.IM_DIM, cfg.IM_SIZE[1], cfg.IM_SIZE[0]), dtype='float32')
        y_b = np.zeros((cfg.BATCH_SIZE, len(cfg.CLASSES)), dtype='float32')

        # ib counts only the samples that were stored successfully.
        ib = 0
        for sample in chunk:
            try:
                # load image data and class label from path
                x, y = loadImageAndTarget(sample, augmentation)

                # pack into batch array
                x_b[ib] = x
                y_b[ib] = y
                ib += 1

            except Exception:
                # Best-effort: skip samples that fail to load or pack. The
                # former bare 'except:' also swallowed KeyboardInterrupt and
                # SystemExit, which made the loader impossible to interrupt.
                continue

        # trim to actual size
        x_b = x_b[:ib]
        y_b = y_b[:ib]

        # NOTE(review): if every sample in the chunk failed, empty arrays are
        # still yielded here - confirm consumers tolerate that.
        yield x_b, y_b

#Loading images with CPU background threads during GPU forward passes saves a lot of time
#Credit: J. Schlüter (https://github.com/Lasagne/Lasagne/issues/12) 
开发者ID:kahst,项目名称:BirdCLEF-Baseline,代码行数:36,代码来源:batch_generator.py

示例11: generate_datasets

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def generate_datasets():
    """Build the batched train/valid/test datasets.

    Each split is loaded with ``get_dataset``. The training split is
    shuffled with a buffer as large as the split; all three are then
    batched with ``config.BATCH_SIZE``.

    :returns: ``(train_dataset, valid_dataset, test_dataset,
        train_count, valid_count, test_count)``.
    """
    roots = (config.train_dir, config.valid_dir, config.test_dir)
    (train_ds, train_count), (valid_ds, valid_count), (test_ds, test_count) = [
        get_dataset(dataset_root_dir=root) for root in roots
    ]

    # read the original_dataset in the form of batch
    bs = config.BATCH_SIZE
    train_ds = train_ds.shuffle(buffer_size=train_count).batch(batch_size=bs)
    valid_ds = valid_ds.batch(batch_size=bs)
    test_ds = test_ds.batch(batch_size=bs)

    return train_ds, valid_ds, test_ds, train_count, valid_count, test_count
开发者ID:calmisential,项目名称:TensorFlow2.0_ResNet,代码行数:14,代码来源:prepare_data.py

示例12: placeholder_inputs

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def placeholder_inputs(config):
    """Create the TensorFlow input placeholders for the model.

    Every placeholder's leading dimension is ``config.BATCH_SIZE``.

    :param config: object providing ``BATCH_SIZE``, ``NUM_POINT``,
        ``NUM_GROUP`` and ``NUM_POINT_INS``.
    :returns: ``(pc_pl, color_pl, pc_ins_pl, group_label_pl,
        group_indicator_pl, seg_label_pl, bbox_ins_pl)``.
    """
    batch = config.BATCH_SIZE
    n_point, n_group = config.NUM_POINT, config.NUM_GROUP

    def make(dtype, *dims):
        # Prepend the shared batch dimension to the given trailing dims.
        return tf.placeholder(dtype, shape=(batch,) + dims)

    # Creation order is unchanged so the auto-generated op names stay the same.
    pc_pl = make(tf.float32, n_point, 3)
    color_pl = make(tf.float32, n_point, 3)
    pc_ins_pl = make(tf.float32, n_group, config.NUM_POINT_INS, 3)
    group_label_pl = make(tf.int32, n_point)
    group_indicator_pl = make(tf.int32, n_group)
    seg_label_pl = make(tf.int32, n_point)
    bbox_ins_pl = make(tf.float32, n_group, 6)
    return (pc_pl, color_pl, pc_ins_pl, group_label_pl,
            group_indicator_pl, seg_label_pl, bbox_ins_pl)
开发者ID:ericyi,项目名称:GSPN,代码行数:11,代码来源:model_rpointnet.py

示例13: __init__

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, rebuild=False, data_aug=False,
             multithread=False, batch_size=cfg.BATCH_SIZE,
             image_size=cfg.IMAGE_SIZE, RGB=False):
    """Set up the ILSVRC 2017 classification reader.

    The ``batch_size`` and ``image_size`` defaults are read from ``cfg``
    when the function is defined. If ``multithread`` is true,
    ``prepare_multithread`` is called and ``self.get`` is bound to
    ``self._get_multithread``; otherwise ``self.get`` is ``self._get``.

    :raises AssertionError: if ``cfg.ILSVRC_PATH`` does not exist.
    """
    self.name = 'ilsvrc_2017_cls'
    # The data lives directly under the ILSVRC root.
    self.devkit_path = self.data_path = cfg.ILSVRC_PATH
    self.cache_path = cfg.CACHE_PATH
    self.batch_size, self.image_size = batch_size, image_size
    self.image_set = image_set
    self.rebuild, self.multithread = rebuild, multithread
    self.data_aug, self.RGB = data_aug, RGB
    self.load_classes()
    # Iteration state.
    self.cursor, self.epoch = 0, 1
    self.gt_labels = None
    assert os.path.exists(self.devkit_path), \
        'ILSVRC path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()

    if not self.multithread:
        self.get = self._get
        return
    self.prepare_multithread()
    self.get = self._get_multithread
开发者ID:wenxichen,项目名称:tensorflow_yolo2,代码行数:31,代码来源:ilsvrc2017_cls_multithread.py

示例14: __init__

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, batch_size=cfg.BATCH_SIZE, rebuild=False):
    """Set up the PASCAL VOC 2007 dataset reader.

    :param image_set: stored as ``self.image_set``.
    :param batch_size: defaults to ``cfg.BATCH_SIZE`` (read at definition time).
    :param rebuild: stored as ``self.rebuild``.
    :raises AssertionError: if ``cfg.PASCAL_PATH`` or its ``VOC2007``
        subdirectory does not exist.
    """
    root = cfg.PASCAL_PATH
    self.name = 'voc_2007'
    self.devkit_path = root
    self.data_path = os.path.join(root, 'VOC2007')
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = batch_size
    self.image_size = cfg.IMAGE_SIZE
    # cfg.S: presumably the YOLO grid size - confirm against config.
    self.cell_size = cfg.S

    # The 20 VOC object categories; a class's label id is its position here.
    voc_classes = (
        'aeroplane', 'bicycle', 'bird', 'boat',
        'bottle', 'bus', 'car', 'cat', 'chair',
        'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant',
        'sheep', 'sofa', 'train', 'tvmonitor',
    )
    self.classes = voc_classes
    self.num_class = len(voc_classes)
    self.class_to_ind = {label: idx for idx, label in enumerate(voc_classes)}

    self.flipped = cfg.FLIPPED
    self.image_set = image_set
    self.rebuild = rebuild
    self.cursor = 0
    self.gt_labels = None
    assert os.path.exists(self.devkit_path), \
        'VOCdevkit path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
开发者ID:wenxichen,项目名称:tensorflow_yolo2,代码行数:28,代码来源:pascal_voc.py

示例15: __init__

# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, rebuild=False,
             multithread=False, batch_size=cfg.BATCH_SIZE,
             image_size=cfg.IMAGE_SIZE, random_noise=False):
    """Set up the ILSVRC 2017 classification reader (scipy variant).

    The ``batch_size`` and ``image_size`` defaults are read from ``cfg``
    when the function is defined. With ``multithread`` set,
    ``prepare_multithread`` runs and ``self.get`` is
    ``self._get_multithread``; otherwise ``self.get`` is ``self._get``.

    :raises AssertionError: if ``cfg.ILSVRC_PATH`` does not exist.
    """
    self.name = 'ilsvrc_2017_cls'
    # The data lives directly under the ILSVRC root.
    self.devkit_path = self.data_path = cfg.ILSVRC_PATH
    self.cache_path = cfg.CACHE_PATH
    self.batch_size, self.image_size = batch_size, image_size
    self.image_set = image_set
    self.rebuild, self.multithread = rebuild, multithread
    self.random_noise = random_noise
    self.load_classes()
    # Iteration state.
    self.cursor, self.epoch = 0, 1
    self.gt_labels = None
    assert os.path.exists(self.devkit_path), \
        'ILSVRC path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()

    if not self.multithread:
        self.get = self._get
        return
    self.prepare_multithread()
    self.get = self._get_multithread
开发者ID:wenxichen,项目名称:tensorflow_yolo2,代码行数:30,代码来源:ilsvrc_cls_multithread_scipy.py


注:本文中的config.BATCH_SIZE属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。