This article collects typical code examples showing how the Python method data.DataProvider.get_next_batch is used. If you are wondering what DataProvider.get_next_batch does, how to call it, or what it looks like in real code, the curated examples below may help. You can also explore further usage examples of the containing class, data.DataProvider.
The text below shows 1 code example of the DataProvider.get_next_batch method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
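Before the full example, here is a minimal sketch of the calling convention the Trainer class below relies on. The three-value return (epoch, batch number, data dictionary), the 'data'/'labels' keys, and the constructor argument order are assumptions inferred from how the example uses the method, not a documented API; the concrete values are placeholders.

from data import DataProvider

# Assumed constructor order, mirroring the example below: (batch_size, data_dir, batch_range).
dp = DataProvider(128, './data', range(1, 5))

# Assumed return signature: (epoch index, batch number, dict of arrays).
epoch, batchnum, batch = dp.get_next_batch()
images = batch['data']     # assumed layout: one case per column
labels = batch['labels']   # assumed: one label per case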
Example 1: __init__
# Required import: from data import DataProvider [as alias]
# Or: from data.DataProvider import get_next_batch [as alias]
import os
import re
import time
import cPickle

import numpy as np

from data import DataProvider
# FastNet, the TRAIN/TEST flags and the two-argument ceil() helper are assumed to be
# provided by other modules of the same project; their imports are not shown here.
class Trainer:
    CHECKPOINT_REGEX = None

    def __init__(self, test_id, data_dir, checkpoint_dir, train_range, test_range, test_freq,
                 save_freq, batch_size, num_epoch, image_size, image_color, learning_rate, n_out,
                 autoInit=True, adjust_freq=1, factor=1.0):
        self.test_id = test_id
        self.data_dir = data_dir
        self.checkpoint_dir = checkpoint_dir
        self.train_range = train_range
        self.test_range = test_range
        self.test_freq = test_freq
        self.save_freq = save_freq
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.image_size = image_size
        self.image_color = image_color
        self.learning_rate = learning_rate
        self.n_out = n_out
        self.factor = factor
        self.adjust_freq = adjust_freq
        # Matches checkpoint files named "test<id>-<epoch>.<batch>".
        self.regex = re.compile(r'^test%d-(\d+)\.(\d+)$' % self.test_id)

        self.init_data_provider()
        self.image_shape = (self.batch_size, self.image_color, self.image_size, self.image_size)
        self.train_outputs = []
        self.test_outputs = []
        self.net = FastNet(self.learning_rate, self.image_shape, self.n_out, autoAdd=autoInit)

        self.num_batch = self.curr_epoch = self.curr_batch = 0
        self.train_data = None
        self.test_data = None
        self.num_train_minibatch = 0
        self.num_test_minibatch = 0
        self.checkpoint_file = ''
    def init_data_provider(self):
        self.train_dp = DataProvider(self.batch_size, self.data_dir, self.train_range)
        self.test_dp = DataProvider(self.batch_size, self.data_dir, self.test_range)
    def get_next_minibatch(self, i, train=TRAIN):
        if train == TRAIN:
            num = self.num_train_minibatch
            data = self.train_data
        else:
            num = self.num_test_minibatch
            data = self.test_data

        batch_data = data['data']
        batch_label = data['labels']
        batch_size = self.batch_size

        if i == num - 1:
            # Last minibatch: take all remaining cases (the original slice ended at -1,
            # which silently dropped the final case).
            input = batch_data[:, i * batch_size:]
            label = batch_label[i * batch_size:]
        else:
            input = batch_data[:, i * batch_size: (i + 1) * batch_size]
            label = batch_label[i * batch_size: (i + 1) * batch_size]

        return input, label
    def save_checkpoint(self):
        model = {}
        model['batchnum'] = self.train_dp.get_batch_num()
        model['epoch'] = self.num_epoch + 1
        model['layers'] = self.net.get_dumped_layers()
        model['train_outputs'] = self.train_outputs
        model['test_outputs'] = self.test_outputs

        dic = {'model_state': model, 'op': None}
        # Remove older checkpoints for this test id before writing the new one.
        saved_filename = [f for f in os.listdir(self.checkpoint_dir) if self.regex.match(f)]
        for f in saved_filename:
            os.remove(os.path.join(self.checkpoint_dir, f))
        checkpoint_filename = "test%d-%d.%d" % (self.test_id, self.curr_epoch, self.curr_batch)
        checkpoint_file_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
        self.checkpoint_file = checkpoint_file_path
        print checkpoint_file_path

        with open(checkpoint_file_path, 'w') as f:
            cPickle.dump(dic, f)
    def get_test_error(self):
        start = time.time()
        _, _, self.test_data = self.test_dp.get_next_batch()

        # ceil() here is the project's two-argument ceiling-division helper.
        self.num_test_minibatch = ceil(self.test_data['data'].shape[1], self.batch_size)
        for i in range(self.num_test_minibatch):
            input, label = self.get_next_minibatch(i, TEST)
            label = np.array(label).astype(np.float32)
            label.shape = (label.size, 1)
            self.net.train_batch(input, label, TEST)

        cost, correct, numCase = self.net.get_batch_information()
        self.test_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
        print 'error: %f logreg: %f time: %f' % (1 - correct, cost, time.time() - start)
    def check_continue_trainning(self):
        return self.curr_epoch <= self.num_epoch
#.........the rest of the code is omitted here.........
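For orientation, the following is a hypothetical driver loop built only from the constructor and methods shown above; the concrete paths, ranges and hyper-parameters are placeholder values, the ordering of the values returned by get_next_batch is assumed from the test path in get_test_error(), and the loop body mirrors that method rather than the elided training code.

# Hypothetical usage sketch; every concrete value below is a placeholder.
trainer = Trainer(test_id=1, data_dir='./data', checkpoint_dir='./checkpoints',
                  train_range=range(1, 41), test_range=range(41, 49), test_freq=20,
                  save_freq=100, batch_size=128, num_epoch=30, image_size=32,
                  image_color=3, learning_rate=0.1, n_out=10)

while trainer.check_continue_trainning():
    # Pull one training batch and split it into minibatches, as get_test_error() does.
    trainer.curr_epoch, trainer.curr_batch, trainer.train_data = trainer.train_dp.get_next_batch()
    trainer.num_train_minibatch = ceil(trainer.train_data['data'].shape[1], trainer.batch_size)
    for i in range(trainer.num_train_minibatch):
        batch_input, batch_label = trainer.get_next_minibatch(i, TRAIN)
        batch_label = np.array(batch_label).astype(np.float32)
        batch_label.shape = (batch_label.size, 1)
        trainer.net.train_batch(batch_input, batch_label, TRAIN)
    if trainer.curr_batch % trainer.test_freq == 0:
        trainer.get_test_error()
    if trainer.curr_batch % trainer.save_freq == 0:
        trainer.save_checkpoint()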