本文整理汇总了 Python 中 config.BATCH_SIZE 属性的典型用法代码示例。如果您正苦于以下问题：Python config.BATCH_SIZE 属性的具体用法？Python config.BATCH_SIZE 怎么用？Python config.BATCH_SIZE 使用的例子？那么，这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在类 config 的用法示例。
在下文中一共展示了config.BATCH_SIZE属性的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: setup
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def setup(self, bottom, top):
    """Set up the RoIDataLayer.

    Reads the batch size from ``config.BATCH_SIZE``, derives the number of
    triplets per batch, creates the sample container and gives the two top
    blobs their initial shapes.

    Args:
        bottom: Bottom blobs; not used by this method.
        top: Top blobs; ``top[0]`` is reshaped for image data and
            ``top[1]`` for labels.

    Raises:
        AssertionError: If the batch size is not a multiple of 3.
    """
    # The layer parameter string must be valid YAML. The parsed value is not
    # used below; the call is kept so malformed parameters still fail here.
    # NOTE(review): yaml.load without an explicit Loader can construct
    # arbitrary objects -- switch to yaml.safe_load if param_str_ can ever
    # come from an untrusted source.
    yaml.load(self.param_str_)

    self._batch_size = config.BATCH_SIZE
    # Validate before dividing so that _triplet is never a truncated value.
    assert self._batch_size % 3 == 0
    # Floor division keeps the triplet count an int on Python 3 as well.
    self._triplet = self._batch_size // 3

    self._name_to_top_map = {
        'data': 0,
        'labels': 1,
    }
    self.data_container = sampledata()
    self._index = 0

    # Data blob: a batch of N images, each with 3 channels, 224 x 224.
    top[0].reshape(self._batch_size, 3, 224, 224)
    # Label blob: one entry per image in the batch.
    top[1].reshape(self._batch_size)
示例2: __init__
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, val_split, rebuild=False, data_aug=False):
    """Configure the TF_flowers dataset wrapper and prepare its data.

    Args:
        val_split: Stored as ``self.val_split``.
        rebuild: Stored as ``self.rebuild``.
        data_aug: Stored as ``self.data_aug``.

    Raises:
        AssertionError: If the configured dataset path does not exist.
    """
    # Constructor arguments.
    self.val_split = val_split
    self.rebuild = rebuild
    self.data_aug = data_aug

    # Identity and project-level configuration.
    self.name = 'TF_flowers'
    self.devkit_path = cfg.FLOWERS_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = cfg.BATCH_SIZE
    self.image_size = cfg.IMAGE_SIZE

    # Class names and the name -> index lookup.
    self.num_class = 5
    self.classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    self.class_to_ind = {
        label: index for index, label in enumerate(self.classes)
    }

    # Iteration state.
    self.train_cursor = 0
    self.val_cursor = 0
    self.epoch = 1
    self.gt_labels = None

    # Both paths must exist before any data is prepared.
    required_paths = (
        (self.devkit_path, 'TF_flowers path does not exist: {}'),
        (self.data_path, 'Path does not exist: {}'),
    )
    for location, template in required_paths:
        assert os.path.exists(location), template.format(location)

    self.prepare()
示例3: __init__
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, rebuild=False, data_aug=True):
    """Configure the ilsvrc_2017 dataset wrapper and prepare its data.

    Args:
        image_set: Stored as ``self.image_set``.
        rebuild: Stored as ``self.rebuild``.
        data_aug: Stored as ``self.data_aug``.

    Raises:
        AssertionError: If the configured ILSVRC path does not exist.
    """
    self.name = 'ilsvrc_2017'
    self.devkit_path = cfg.ILSVRC_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH
    self.batch_size = cfg.BATCH_SIZE
    self.image_size = cfg.IMAGE_SIZE
    self.image_set = image_set
    self.rebuild = rebuild
    self.data_aug = data_aug
    self.cursor = 0
    self.load_classes()
    # The message names the ILSVRC dataset (it previously said "VOCdevkit",
    # a leftover from the Pascal VOC loader).
    assert os.path.exists(self.devkit_path), \
        'ILSVRC path does not exist: {}'.format(self.devkit_path)
    assert os.path.exists(self.data_path), \
        'Path does not exist: {}'.format(self.data_path)
    self.prepare()
示例4: __init__
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self):
    """Load the 2007 and 2012 trainval annotations and set up batching.

    Always keep in mind: small_detector corresponds to index 0,
    medium_detector to index 1 and big_detector to index 2.
    """
    # Settings copied from the project configuration.
    self.__dataset_path = cfg.DATASET_PATH
    self.__train_input_sizes = cfg.TRAIN_INPUT_SIZES
    self.__strides = np.array(cfg.STRIDES)
    self.__batch_size = cfg.BATCH_SIZE
    self.__classes = cfg.CLASSES
    self.__num_classes = len(self.__classes)
    self.__gt_per_grid = cfg.GT_PER_GRID

    # Class name -> class index lookup.
    self.__class_to_ind = {
        class_name: class_index
        for class_index, class_name in enumerate(self.__classes)
    }

    # The 2007 annotations come first, followed by the 2012 ones.
    path_2007 = os.path.join(self.__dataset_path, '2007_trainval')
    path_2012 = os.path.join(self.__dataset_path, '2012_trainval')
    loaded_2007 = self.__load_annotations(path_2007)
    loaded_2012 = self.__load_annotations(path_2012)
    self.__annotations = loaded_2007 + loaded_2012
    self.__num_samples = len(self.__annotations)

    label = 'The number of image for train is:'.ljust(50)
    logging.info(label + str(self.__num_samples))

    # Batches per pass over the data, rounded up by np.ceil.
    self.__num_batchs = np.ceil(self.__num_samples / self.__batch_size)
    self.__batch_count = 0
示例5: main
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def main():
    """Train the LeNet model on MNIST and save its parameters.

    Restores parameters from ``cfg.PARAMETER_FILE`` when that path exists,
    otherwise initialises all variables. Runs ``cfg.MAX_ITER`` training
    steps with batches of ``cfg.BATCH_SIZE`` samples, printing the training
    accuracy every 100 steps, then saves the parameters back to the same
    path.
    """
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    batch_size = cfg.BATCH_SIZE
    parameter_path = cfg.PARAMETER_FILE
    max_iter = cfg.MAX_ITER
    lenet = Lenet()

    # The Saver is created after the model so it sees the model's variables.
    saver = tf.train.Saver()

    # The context manager closes the session when training is done.
    with tf.Session() as sess:
        if os.path.exists(parameter_path):
            # Saver.restore needs the session as its first argument.
            saver.restore(sess, parameter_path)
        else:
            # NOTE(review): initialize_all_variables is deprecated in favour
            # of tf.global_variables_initializer in later TF 1.x releases.
            sess.run(tf.initialize_all_variables())

        for i in range(max_iter):
            # Use the configured batch size instead of a hard-coded 50.
            batch = mnist.train.next_batch(batch_size)
            feed = {
                lenet.raw_input_image: batch[0],
                lenet.raw_input_label: batch[1],
            }
            if i % 100 == 0:
                train_accuracy = sess.run(lenet.train_accuracy, feed_dict=feed)
                print("step %d, training accuracy %g" % (i, train_accuracy))
            sess.run(lenet.train_op, feed_dict=feed)

        saver.save(sess, parameter_path)
示例6: setup
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def setup(self, bottom, top):
    """Set up the DataLayer.

    Uses ``cfg.TRIPLET_BATCH_SIZE`` as the batch size when
    ``cfg.TRIPLET_LOSS`` is truthy and ``cfg.BATCH_SIZE`` otherwise, then
    resets the iteration counters and shapes the two top blobs.

    Args:
        bottom: Bottom blobs; not used by this method.
        top: Top blobs; ``top[0]`` holds image data, ``top[1]`` labels.
    """
    self.batch_size = (
        cfg.TRIPLET_BATCH_SIZE if cfg.TRIPLET_LOSS else cfg.BATCH_SIZE
    )
    self._name_to_top_map = {'data': 0, 'labels': 1}
    self._index = 0
    self._epoch = 1

    data_blob = top[0]
    label_blob = top[1]
    # Data blob: a batch of N images, each with 3 channels, 224 x 224.
    data_blob.reshape(self.batch_size, 3, 224, 224)
    # Label blob: one entry per image in the batch.
    label_blob.reshape(self.batch_size)
示例7: setup
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def setup(self, bottom, top):
    """Set up the TripletSelectLayer.

    Stores ``config.BATCH_SIZE // 3`` as the triplet count and reshapes the
    three top blobs to ``(triplet, d)``, where ``d`` is the size of the
    second axis of ``bottom[0].data``.

    Args:
        bottom: Bottom blobs; only ``bottom[0].data`` is read.
        top: Top blobs; ``top[0]``, ``top[1]`` and ``top[2]`` are reshaped.
    """
    # Floor division keeps the blob dimension an int on Python 3 as well.
    self.triplet = config.BATCH_SIZE // 3
    # The second-axis size is the same for all three outputs; read it once.
    feature_dim = shape(bottom[0].data)[1]
    for blob_index in range(3):
        top[blob_index].reshape(self.triplet, feature_dim)
示例8: showProgress
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def showProgress(epoch, done=False):
    """Print a textual progress bar for the current epoch.

    Prints ``EPOCH <epoch> [`` when no batch has been counted yet, one ``=``
    each time the integer completion percentage reaches a new multiple of 5,
    and ``]`` when ``done`` is true.

    Args:
        epoch: Epoch number shown at the start of the bar.
        done: When true, only the closing bracket is printed.
    """
    global last_update

    # No 'batch_count' entry yet means this is the first call.
    bcnt = cfg.STATS['batch_count'] if 'batch_count' in cfg.STATS else 0

    # Number of batches to train.
    total_batches = cfg.STATS['sample_count'] // cfg.BATCH_SIZE + 1

    if done:
        log.p(']', new_line=False)
    elif bcnt == 0:
        log.p(('EPOCH', epoch, '['), new_line=False)
    else:
        # Integer percentage: with true division the value is a float and
        # `p % 5 == 0` almost never holds, so the bar would not advance.
        p = bcnt * 100 // total_batches
        if p % 5 == 0 and p != last_update:
            log.p('=', new_line=False)
            last_update = p
# Clear on first load
示例9: getDatasetChunk
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def getDatasetChunk(split, batch_size=None):
    """Yield consecutive batch-sized slices of ``split``.

    Args:
        split: A sliceable sequence (for example a list of image paths).
        batch_size: Maximum slice length. Defaults to ``cfg.BATCH_SIZE``
            when omitted, which is what existing callers rely on.

    Yields:
        Slices ``split[i:i + batch_size]``; the final slice may be shorter.
    """
    if batch_size is None:
        batch_size = cfg.BATCH_SIZE
    # `range` replaces the Python-2-only `xrange`.
    for start in range(0, len(split), batch_size):
        yield split[start:start + batch_size]
示例10: getNextImageBatch
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def getNextImageBatch(split, augmentation=True):
    """Yield (images, targets) batches built from the samples in ``split``.

    Each chunk from ``getDatasetChunk`` is loaded sample by sample with
    ``loadImageAndTarget``. Samples that fail to load are skipped, so a
    yielded batch can hold fewer than ``cfg.BATCH_SIZE`` rows.

    Args:
        split: Sequence of samples, passed on to ``getDatasetChunk``.
        augmentation: Passed through to ``loadImageAndTarget``.

    Yields:
        Tuple ``(x_b, y_b)`` of float32 arrays trimmed to the number of
        samples that loaded successfully.
    """
    for chunk in getDatasetChunk(split):
        # Allocate full-size arrays for image data and targets.
        x_b = np.zeros(
            (cfg.BATCH_SIZE, cfg.IM_DIM, cfg.IM_SIZE[1], cfg.IM_SIZE[0]),
            dtype='float32')
        y_b = np.zeros((cfg.BATCH_SIZE, len(cfg.CLASSES)), dtype='float32')

        ib = 0  # number of rows filled so far
        for sample in chunk:
            try:
                # Load image data and class label, then pack into the batch.
                x, y = loadImageAndTarget(sample, augmentation)
                x_b[ib] = x
                y_b[ib] = y
                ib += 1
            except Exception:
                # Best effort: skip samples that cannot be loaded or packed.
                # `Exception` (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                continue

        # Trim to the number of rows actually filled.
        x_b = x_b[:ib]
        y_b = y_b[:ib]

        yield x_b, y_b
#Loading images with CPU background threads during GPU forward passes saves a lot of time
#Credit: J. Schlüter (https://github.com/Lasagne/Lasagne/issues/12)
示例11: generate_datasets
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def generate_datasets():
    """Build the batched train, validation and test datasets.

    Returns:
        Tuple ``(train_dataset, valid_dataset, test_dataset, train_count,
        valid_count, test_count)``. Only the training dataset is shuffled,
        with a buffer of ``train_count`` elements; all three are batched
        with ``config.BATCH_SIZE``.
    """
    raw_train, train_count = get_dataset(dataset_root_dir=config.train_dir)
    raw_valid, valid_count = get_dataset(dataset_root_dir=config.valid_dir)
    raw_test, test_count = get_dataset(dataset_root_dir=config.test_dir)

    # Shuffle the training data first, then split everything into batches.
    shuffled_train = raw_train.shuffle(buffer_size=train_count)
    train_dataset = shuffled_train.batch(batch_size=config.BATCH_SIZE)
    valid_dataset = raw_valid.batch(batch_size=config.BATCH_SIZE)
    test_dataset = raw_test.batch(batch_size=config.BATCH_SIZE)

    return (train_dataset, valid_dataset, test_dataset,
            train_count, valid_count, test_count)
示例12: placeholder_inputs
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def placeholder_inputs(config):
    """Create the seven input placeholders for the given configuration.

    Args:
        config: Object providing ``BATCH_SIZE``, ``NUM_POINT``,
            ``NUM_GROUP`` and ``NUM_POINT_INS``.

    Returns:
        Tuple ``(pc_pl, color_pl, pc_ins_pl, group_label_pl,
        group_indicator_pl, seg_label_pl, bbox_ins_pl)``.
    """
    batch = config.BATCH_SIZE
    points = config.NUM_POINT
    groups = config.NUM_GROUP

    # Placeholders are created in the same order in which they are returned.
    pc_pl = tf.placeholder(tf.float32, shape=(batch, points, 3))
    color_pl = tf.placeholder(tf.float32, shape=(batch, points, 3))
    pc_ins_pl = tf.placeholder(
        tf.float32, shape=(batch, groups, config.NUM_POINT_INS, 3))
    group_label_pl = tf.placeholder(tf.int32, shape=(batch, points))
    group_indicator_pl = tf.placeholder(tf.int32, shape=(batch, groups))
    seg_label_pl = tf.placeholder(tf.int32, shape=(batch, points))
    bbox_ins_pl = tf.placeholder(tf.float32, shape=(batch, groups, 6))

    return (pc_pl, color_pl, pc_ins_pl, group_label_pl,
            group_indicator_pl, seg_label_pl, bbox_ins_pl)
示例13: __init__
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, rebuild=False, data_aug=False,
             multithread=False, batch_size=cfg.BATCH_SIZE,
             image_size=cfg.IMAGE_SIZE, RGB=False):
    """Configure the ilsvrc_2017_cls dataset wrapper and prepare its data.

    Args:
        image_set: Stored as ``self.image_set``.
        rebuild: Stored as ``self.rebuild``.
        data_aug: Stored as ``self.data_aug``.
        multithread: When truthy, ``prepare_multithread`` is called and
            ``self.get`` is bound to the multithreaded getter.
        batch_size: Stored as ``self.batch_size``.
        image_size: Stored as ``self.image_size``.
        RGB: Stored as ``self.RGB``.

    Raises:
        AssertionError: If the configured ILSVRC path does not exist.
    """
    # Identity and locations from the project configuration.
    self.name = 'ilsvrc_2017_cls'
    self.devkit_path = cfg.ILSVRC_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH

    # Constructor arguments.
    self.batch_size = batch_size
    self.image_size = image_size
    self.image_set = image_set
    self.rebuild = rebuild
    self.multithread = multithread
    self.data_aug = data_aug
    self.RGB = RGB

    self.load_classes()

    # Iteration state.
    self.cursor = 0
    self.epoch = 1
    self.gt_labels = None

    required_paths = (
        (self.devkit_path, 'ILSVRC path does not exist: {}'),
        (self.data_path, 'Path does not exist: {}'),
    )
    for location, template in required_paths:
        assert os.path.exists(location), template.format(location)

    self.prepare()

    # Bind `get` to the implementation matching the threading mode.
    if not self.multithread:
        self.get = self._get
    else:
        self.prepare_multithread()
        self.get = self._get_multithread
示例14: __init__
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, batch_size=cfg.BATCH_SIZE, rebuild=False):
    """Configure the voc_2007 dataset wrapper and prepare its data.

    Args:
        image_set: Stored as ``self.image_set``.
        batch_size: Stored as ``self.batch_size``.
        rebuild: Stored as ``self.rebuild``.

    Raises:
        AssertionError: If the devkit path or its ``VOC2007`` sub-directory
            does not exist.
    """
    # Constructor arguments.
    self.image_set = image_set
    self.batch_size = batch_size
    self.rebuild = rebuild

    # Identity and project-level configuration.
    self.name = 'voc_2007'
    self.devkit_path = cfg.PASCAL_PATH
    self.data_path = os.path.join(self.devkit_path, 'VOC2007')
    self.cache_path = cfg.CACHE_PATH
    self.image_size = cfg.IMAGE_SIZE
    self.cell_size = cfg.S
    self.flipped = cfg.FLIPPED

    # Class names and the name -> index lookup.
    self.classes = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    )
    self.num_class = len(self.classes)
    self.class_to_ind = {
        label: index for index, label in enumerate(self.classes)
    }

    # Iteration state.
    self.cursor = 0
    self.gt_labels = None

    required_paths = (
        (self.devkit_path, 'VOCdevkit path does not exist: {}'),
        (self.data_path, 'Path does not exist: {}'),
    )
    for location, template in required_paths:
        assert os.path.exists(location), template.format(location)

    self.prepare()
示例15: __init__
# 需要导入模块: import config [as 别名]
# 或者: from config import BATCH_SIZE [as 别名]
def __init__(self, image_set, rebuild=False,
             multithread=False, batch_size=cfg.BATCH_SIZE,
             image_size=cfg.IMAGE_SIZE, random_noise=False):
    """Configure the ilsvrc_2017_cls dataset wrapper and prepare its data.

    Args:
        image_set: Stored as ``self.image_set``.
        rebuild: Stored as ``self.rebuild``.
        multithread: When truthy, ``prepare_multithread`` is called and
            ``self.get`` is bound to the multithreaded getter.
        batch_size: Stored as ``self.batch_size``.
        image_size: Stored as ``self.image_size``.
        random_noise: Stored as ``self.random_noise``.

    Raises:
        AssertionError: If the configured ILSVRC path does not exist.
    """
    # Identity and locations from the project configuration.
    self.name = 'ilsvrc_2017_cls'
    self.devkit_path = cfg.ILSVRC_PATH
    self.data_path = self.devkit_path
    self.cache_path = cfg.CACHE_PATH

    # Constructor arguments.
    self.batch_size = batch_size
    self.image_size = image_size
    self.image_set = image_set
    self.rebuild = rebuild
    self.multithread = multithread
    self.random_noise = random_noise

    self.load_classes()

    # Iteration state.
    self.cursor = 0
    self.epoch = 1
    self.gt_labels = None

    required_paths = (
        (self.devkit_path, 'ILSVRC path does not exist: {}'),
        (self.data_path, 'Path does not exist: {}'),
    )
    for location, template in required_paths:
        assert os.path.exists(location), template.format(location)

    self.prepare()

    # Bind `get` to the implementation matching the threading mode.
    if not self.multithread:
        self.get = self._get
    else:
        self.prepare_multithread()
        self.get = self._get_multithread