本文整理汇总了Python中provider.jitter_point_cloud方法的典型用法代码示例。如果您正苦于以下问题：Python provider.jitter_point_cloud方法的具体用法？Python provider.jitter_point_cloud怎么用？Python provider.jitter_point_cloud使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块provider的用法示例。
在下文中一共展示了provider.jitter_point_cloud方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train(net, opt, scheduler, train_loader, dev):
    """Run one training epoch with point-cloud augmentation.

    Each batch is converted to a NumPy array and augmented with the
    ``provider`` helpers: random point dropout on the whole array, then
    random scaling, jitter and shift applied to the first three channels
    (``0:3``). It is then converted back to a tensor and used for one
    optimizer step with cross-entropy loss. Running average loss and accuracy
    are shown on the tqdm progress bar, and ``scheduler.step()`` is called
    once after the epoch.

    Args:
        net: model mapping a batch of points to class logits.
        opt: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler, stepped once per call.
        train_loader: iterable of ``(data, label)`` tensor pairs; ``label``
            is indexed as ``label[:, 0]``, so it must be at least 2-D.
        dev: device the batch is moved to.
    """
    net.train()
    total_loss = 0
    num_batches = 0
    total_correct = 0
    count = 0
    loss_f = nn.CrossEntropyLoss()
    with tqdm.tqdm(train_loader, ascii=True) as tq:
        for data, label in tq:
            # The provider augmentations operate on NumPy arrays.
            data = data.data.numpy()
            data = provider.random_point_dropout(data)
            data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
            data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
            data = torch.tensor(data)
            label = label[:, 0]
            num_examples = label.shape[0]
            data = data.to(dev)
            # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
            # one into a 0-d tensor, which no longer matches the (1, C)
            # logits expected by CrossEntropyLoss.
            label = label.to(dev).reshape(-1).long()

            opt.zero_grad()
            logits = net(data)
            loss = loss_f(logits, label)
            loss.backward()
            opt.step()

            _, preds = logits.max(1)
            num_batches += 1
            count += num_examples
            total_loss += loss.item()
            total_correct += (preds == label).sum().item()
            tq.set_postfix({
                'AvgLoss': '%.5f' % (total_loss / num_batches),
                'AvgAcc': '%.5f' % (total_correct / count)})
    scheduler.step()
示例2: get_example
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def get_example(self, i):
    """Return the i-th sample as a ``(point_data, label)`` pair.

    When ``self.augment`` is truthy, the length-1 slice ``self.data[i:i + 1]``
    is rotated and then jittered via ``provider``, and its single element is
    used; otherwise ``self.data[i]`` is used as-is. The points are cast to
    float32, their two axes are swapped, and a trailing length-1 axis is
    appended, i.e. ``(num_point, k)`` -> ``(k, num_point, 1)``.
    """
    if not self.augment:
        raw_points = self.data[i]
    else:
        # The provider helpers work on batches, so feed a batch of one.
        augmented_batch = provider.jitter_point_cloud(
            provider.rotate_point_cloud(self.data[i:i + 1, :, :]))
        raw_points = augmented_batch[0]

    as_float = raw_points.astype(np.float32)
    channels_first = np.transpose(as_float, (1, 0))[:, :, None]

    assert channels_first.dtype == np.float32
    assert self.label[i].dtype == np.int32
    return channels_first, self.label[i]
示例3: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch over TRAIN_DATASET, jittering the xyz channels.

    Sample indices are shuffled, then each full batch of BATCH_SIZE is
    fetched with ``get_batch``, jittered on channels ``0:3`` and fed to the
    TF session. Every 10 batches the mean loss and per-point accuracy of
    those batches are logged and the accumulators reset. Any trailing partial
    batch is skipped.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    # Visit the training samples in a fresh random order each epoch.
    train_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_idxs)
    # Floor division: with "/" this is a float on Python 3 and range() raises
    # TypeError. The trailing partial batch is dropped.
    num_batches = len(TRAIN_DATASET) // BATCH_SIZE
    log_string(str(datetime.now()))
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE
        batch_data, batch_label = get_batch(TRAIN_DATASET, train_idxs, start_idx, end_idx)
        # Augment by jittering the xyz channels only (no rotation here).
        batch_data[:, :, 0:3] = provider.jitter_point_cloud(batch_data[:, :, 0:3])
        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['is_training_pl']: is_training}
        summary, step, _, loss_val, pred_val = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        # Predictions are per point: argmax over the last (class) axis.
        pred_val = np.argmax(pred_val, 2)
        correct = np.sum(pred_val == batch_label)
        total_correct += correct
        total_seen += (BATCH_SIZE * NUM_POINT)
        loss_sum += loss_val
        if (batch_idx + 1) % 10 == 0:
            log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches))
            log_string('mean loss: %f' % (loss_sum / 10))
            log_string('accuracy: %f' % (total_correct / float(total_seen)))
            total_correct = 0
            total_seen = 0
            loss_sum = 0
示例4: _augment_batch_data
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def _augment_batch_data(self, batch_data):
    """Augment a batch of point clouds and shuffle the point order.

    The whole array is randomly rotated and then rotation-perturbed. Random
    scaling, shifting and jittering are applied, in that order, to channels
    ``0:3`` only, and the result is written back into those channels before
    ``provider.shuffle_points`` is applied.
    """
    augmented = provider.rotate_perturbation_point_cloud(
        provider.rotate_point_cloud(batch_data))
    xyz = augmented[:, :, 0:3]
    for transform in (provider.random_scale_point_cloud,
                      provider.shift_point_cloud,
                      provider.jitter_point_cloud):
        xyz = transform(xyz)
    augmented[:, :, 0:3] = xyz
    return provider.shuffle_points(augmented)
示例5: _augment_batch_data
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def _augment_batch_data(self, batch_data):
    """Augment a batch of point clouds and shuffle the point order.

    When ``self.normal_channel`` is truthy, the normal-aware rotation
    helpers are used; otherwise the plain ones are. Random scaling, shifting
    and jittering are then applied, in that order, to channels ``0:3`` only,
    and written back before ``provider.shuffle_points`` is applied.
    """
    if self.normal_channel:
        rotate = provider.rotate_point_cloud_with_normal
        perturb = provider.rotate_perturbation_point_cloud_with_normal
    else:
        rotate = provider.rotate_point_cloud
        perturb = provider.rotate_perturbation_point_cloud
    augmented = perturb(rotate(batch_data))

    xyz = augmented[:, :, 0:3]
    xyz = provider.random_scale_point_cloud(xyz)
    xyz = provider.shift_point_cloud(xyz)
    xyz = provider.jitter_point_cloud(xyz)
    augmented[:, :, 0:3] = xyz
    return provider.shuffle_points(augmented)
示例6: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train the part-segmentation model for one epoch and log the results.

    Data, labels and part labels come from
    ``data_utils.get_current_data_parts_h5``. Each full batch of BATCH_SIZE
    is rotated and then jittered via ``provider`` before being fed; any
    trailing partial batch is skipped. After the epoch the mean loss per
    batch and the per-point segmentation accuracy are logged.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    points, labels, parts = data_utils.get_current_data_parts_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_PARTS, NUM_POINT)
    labels = np.squeeze(labels)
    parts = np.squeeze(parts)
    num_batches = points.shape[0] // BATCH_SIZE

    total_seen = 0
    loss_sum = 0
    total_correct_seg = 0
    for batch_idx in range(num_batches):
        window = slice(batch_idx * BATCH_SIZE, (batch_idx + 1) * BATCH_SIZE)
        batch_parts = parts[window]
        # Augment batched point clouds by rotation and jittering.
        augmented = provider.jitter_point_cloud(
            provider.rotate_point_cloud(points[window, :, :]))
        feed_dict = {
            ops['pointclouds_pl']: augmented,
            ops['labels_pl']: labels[window],
            ops['parts_pl']: batch_parts,
            ops['is_training_pl']: is_training,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'],
                   ops['loss'], ops['seg_pred']]
        summary, step, _, loss_val, seg_val = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        predicted_parts = np.argmax(seg_val, 2)
        total_correct_seg += np.sum(predicted_parts == batch_parts)
        total_seen += BATCH_SIZE
        loss_sum += loss_val
    log_string('mean loss: %f' % (loss_sum / float(num_batches)))
    log_string('seg accuracy: %f' % (total_correct_seg / (float(total_seen)*NUM_POINT)))
示例7: get_kmeans_init
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def get_kmeans_init(n_gaussians, cov_type='farthest', num_points=1024,
                    n_classes=40, samples_per_class=10):
    """Initialise diagonal-Gaussian mixture parameters with k-means.

    Loads the dataset via ``provider.load_dataset``. For each label in
    ``range(n_classes)`` it keeps at most ``samples_per_class`` models,
    jitters them (sigma=0.01, clip=0.05), pools all their points, and
    clusters the pooled points with TensorFlow's contrib
    ``KMeansClustering``.

    Args:
        n_gaussians: number of clusters / mixture components.
        cov_type: ``'compute_cov'`` uses the per-axis standard deviation of
            the points assigned to each cluster. ``'farthest'`` uses one
            third of the distance from the cluster centre to its farthest
            assigned point, repeated on every axis.
        num_points: points per model requested from ``provider.load_dataset``.
        n_classes: labels ``0 .. n_classes - 1`` to sample models from.
        samples_per_class: maximum number of models taken per label.

    Returns:
        ``(w, mu, sigma)``:
            - ``w`` is a list with the fraction of points assigned to each
              cluster.
            - ``mu`` is the transposed cluster-centre array.
            - ``sigma`` is a float32 array of per-cluster standard
              deviations, transposed to ``(D, n_gaussians)``.

    Raises:
        ValueError: if ``cov_type`` is not one of the supported values.
    """
    if cov_type not in ('compute_cov', 'farthest'):
        # Previously an unknown value silently produced an empty/ragged stdev.
        raise ValueError(
            "cov_type must be 'compute_cov' or 'farthest', got %r" % (cov_type,))
    D = 3
    # Load multiple models from the dataset to use as initialization data.
    points, labels, _, _ = provider.load_dataset(num_points=num_points)
    # np.flatnonzero always yields a 1-D index array; np.squeeze(np.where(..))
    # collapses to 0-d when exactly one model matches, which breaks slicing.
    mask = np.concatenate(
        [np.flatnonzero(labels == i)[:samples_per_class] for i in range(n_classes)],
        axis=0)
    points = points[mask, :, :]
    points = provider.jitter_point_cloud(points, sigma=0.01, clip=0.05)
    # Pool the points of every selected model into one 2-D array.
    points = np.concatenate(points, axis=0)

    def input_fn():
        """Input function feeding the pooled points to the k-means estimator."""
        return tf.constant(points, dtype=tf.float32), None

    kmeans = tf.contrib.learn.KMeansClustering(num_clusters=n_gaussians,
                                               relative_tolerance=0.0001)
    kmeans.fit(input_fn=input_fn)
    centers = kmeans.clusters()
    assignments = np.squeeze(list(kmeans.predict_cluster_idx(input_fn=input_fn)))
    n_points = points.shape[0]
    stdev = []
    w = []
    for i in range(n_gaussians):
        # 1-D even when a single point is assigned (len() on 0-d would fail).
        idx = np.flatnonzero(assignments == i)
        # float() keeps this a true ratio under Python 2 integer division too.
        w.append(len(idx) / float(n_points))
        cluster_points = points[idx, :]
        if cov_type == 'compute_cov':
            stdev.append(np.sqrt(np.diag(np.cov(cluster_points.T))))
        else:  # 'farthest'
            d = np.sqrt(np.sum(np.power(cluster_points - centers[i, :], 2), axis=1))
            stdev.append((np.max(d) / 3.) * np.ones(D))
    return w, centers.T, np.array(stdev, dtype=np.float32).T
示例8: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train the classifier for one epoch and log mean loss and accuracy.

    Data comes from ``data_utils.get_current_data_h5`` when ``TRAIN_FILE``
    contains ".h5", and from ``data_utils.get_current_data`` otherwise. Each
    full batch of BATCH_SIZE is rotated and then jittered via ``provider``
    before being fed; any trailing partial batch is skipped.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    load = (data_utils.get_current_data_h5 if ".h5" in TRAIN_FILE
            else data_utils.get_current_data)
    current_data, current_label = load(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    current_label = np.squeeze(current_label)
    num_batches = current_data.shape[0] // BATCH_SIZE

    total_correct, total_seen, loss_sum = 0, 0, 0
    for start_idx in range(0, num_batches * BATCH_SIZE, BATCH_SIZE):
        end_idx = start_idx + BATCH_SIZE
        batch_label = current_label[start_idx:end_idx]
        # Augment batched point clouds by rotation and jittering.
        augmented = provider.jitter_point_cloud(
            provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :]))
        feed_dict = {ops['pointclouds_pl']: augmented,
                     ops['labels_pl']: batch_label,
                     ops['is_training_pl']: is_training}
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']]
        summary, step, _, loss_val, pred_val = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        total_correct += np.sum(np.argmax(pred_val, 1) == batch_label)
        total_seen += BATCH_SIZE
        loss_sum += loss_val
    log_string('mean loss: %f' % (loss_sum / float(num_batches)))
    log_string('accuracy: %f' % (total_correct / float(total_seen)))
示例9: get_batch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def get_batch(self, data_aug=False):
    """Take the next ``(data, sem_label, ins_label)`` triple from ``self.data_queue``.

    When ``data_aug`` is truthy and ``self.split`` is ``'train'``, channels
    ``0:3`` of ``data`` are replaced with their jittered version from
    ``provider.jitter_point_cloud``. Otherwise the triple is returned
    unchanged.
    """
    data, sem_label, ins_label = self.data_queue.get()
    if not (data_aug and self.split == 'train'):
        return data, sem_label, ins_label
    data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
    return data, sem_label, ins_label
示例10: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer, epoch):
    """Train the pairwise point-cloud model for one epoch.

    A deep copy of TRAIN_INDICES is shuffled and split into full batches of
    ``cfg.training.batch_size``; any trailing partial batch is skipped. Each
    batch is loaded with ``provider.load_batch``, both point-cloud batches
    are jittered, and one training step is run. The last loss is shown on
    the tqdm bar, and the mean loss is logged after the epoch. ``epoch`` is
    accepted but not used in this body.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
        epoch: current epoch number (unused here).
    """
    is_training = True
    batch_size = cfg.training.batch_size
    order = copy.deepcopy(TRAIN_INDICES)
    np.random.shuffle(order)
    num_batches = len(order) // batch_size
    loss_sum = 0
    progress = tqdm(range(num_batches), desc='train', postfix=dict(last_loss_str=''))
    for batch_idx in progress:
        lo = batch_idx * batch_size
        hi = lo + batch_size
        (pcs1, pcs2, translations, rel_angles,
         pc1centers, pc2centers, pc1angles, pc2angles) = provider.load_batch(order[lo:hi])
        # Augment both batched point clouds by jittering.
        pcs1 = provider.jitter_point_cloud(pcs1)
        pcs2 = provider.jitter_point_cloud(pcs2)
        feed_dict = {
            ops['pcs1']: pcs1,
            ops['pcs2']: pcs2,
            ops['translations']: translations,
            ops['rel_angles']: rel_angles,
            ops['is_training_pl']: is_training,
            ops['pc1centers']: pc1centers,
            ops['pc2centers']: pc2centers,
            ops['pc1angles']: pc1angles,
            ops['pc2angles']: pc2angles,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                   ops['pred_translations'], ops['pred_remaining_angle_logits']]
        summary, step, _, loss_val, _, _ = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        loss_sum += loss_val
        progress.set_postfix(last_loss_str=f'{loss_val:.5f}')
    logger.info('train mean loss: %f' % (loss_sum / float(num_batches)))
    train_writer.flush()
示例11: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch over every file in TRAIN_FILES, in random order.

    For each file, the points are loaded with
    ``provider.loadDataFile_with_normal`` (the third return value is
    discarded), truncated to the first NUM_POINT points, and shuffled with
    ``provider.shuffle_data``. Each full batch is then rotated and jittered
    before being fed. Loss and accuracy accumulators are reset per file and
    logged after each file.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    file_order = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(file_order)
    for fn, file_idx in enumerate(file_order):
        log_string('----' + str(fn) + '-----')
        points, labels, _ = provider.loadDataFile_with_normal(TRAIN_FILES[file_idx])
        points = points[:, 0:NUM_POINT, :]
        points, labels, _ = provider.shuffle_data(points, np.squeeze(labels))
        labels = np.squeeze(labels)
        num_batches = points.shape[0] // BATCH_SIZE

        total_correct = 0
        total_seen = 0
        loss_sum = 0
        for batch_idx in range(num_batches):
            window = slice(batch_idx * BATCH_SIZE, (batch_idx + 1) * BATCH_SIZE)
            batch_label = labels[window]
            # Augment batched point clouds by rotation and jittering.
            augmented = provider.jitter_point_cloud(
                provider.rotate_point_cloud(points[window, :, :]))
            feed_dict = {ops['pointclouds_pl']: augmented,
                         ops['labels_pl']: batch_label,
                         ops['is_training_pl']: is_training}
            fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']]
            summary, step, _, loss_val, pred_val = sess.run(fetches, feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            total_correct += np.sum(np.argmax(pred_val, 1) == batch_label)
            total_seen += BATCH_SIZE
            loss_sum += loss_val
        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
示例12: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch over TRAIN_DATASET with per-point and class labels.

    Sample indices are shuffled, then each full batch of BATCH_SIZE is
    fetched with ``get_batch`` (points, per-point labels, class labels). The
    points are jittered on channels ``0:3`` and fed to the TF session. Every
    10 batches the mean loss and per-point accuracy of those batches are
    logged and the accumulators reset. Any trailing partial batch is skipped.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    # Visit the training samples in a fresh random order each epoch.
    train_idxs = np.arange(0, len(TRAIN_DATASET))
    np.random.shuffle(train_idxs)
    # Floor division: with "/" this is a float on Python 3 and range() raises
    # TypeError. The trailing partial batch is dropped.
    num_batches = len(TRAIN_DATASET) // BATCH_SIZE
    log_string(str(datetime.now()))
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx + 1) * BATCH_SIZE
        batch_data, batch_label, batch_cls_label = get_batch(
            TRAIN_DATASET, train_idxs, start_idx, end_idx)
        # Augment by jittering the xyz channels only (no rotation here).
        batch_data[:, :, 0:3] = provider.jitter_point_cloud(batch_data[:, :, 0:3])
        feed_dict = {ops['pointclouds_pl']: batch_data,
                     ops['labels_pl']: batch_label,
                     ops['cls_labels_pl']: batch_cls_label,
                     ops['is_training_pl']: is_training}
        summary, step, _, loss_val, pred_val = sess.run(
            [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        # Predictions are per point: argmax over the last (class) axis.
        pred_val = np.argmax(pred_val, 2)
        correct = np.sum(pred_val == batch_label)
        total_correct += correct
        total_seen += (BATCH_SIZE * NUM_POINT)
        loss_sum += loss_val
        if (batch_idx + 1) % 10 == 0:
            log_string(' -- %03d / %03d --' % (batch_idx + 1, num_batches))
            log_string('mean loss: %f' % (loss_sum / 10))
            log_string('accuracy: %f' % (total_correct / float(total_seen)))
            total_correct = 0
            total_seen = 0
            loss_sum = 0
示例13: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train for one epoch over TRAIN_FILES, feeding xyz plus per-point normals.

    For each file, in random order:
    - Points and normals are loaded with
      ``provider.loadDataFile_with_normal`` and truncated to NUM_POINT.
    - The points are shuffled with ``provider.shuffle_data``, and the normals
      are reindexed with the index it returns.
    - For each full batch, the xyz points are rotated and jittered, the
      normals are concatenated as extra channels (axis 2), and
      ``provider.random_point_dropout`` is applied.

    Accumulators are reset and logged per file.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    file_order = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(file_order)
    for fn, file_idx in enumerate(file_order):
        log_string('----' + str(fn) + '-----')
        points, labels, normals = provider.loadDataFile_with_normal(TRAIN_FILES[file_idx])
        normals = normals[:, 0:NUM_POINT, :]
        points = points[:, 0:NUM_POINT, :]
        points, labels, order = provider.shuffle_data(points, np.squeeze(labels))
        labels = np.squeeze(labels)
        # Reorder normals with the index returned by shuffle_data, presumably
        # the same permutation applied to the points.
        normals = normals[order, ...]
        num_batches = points.shape[0] // BATCH_SIZE

        total_correct = 0
        total_seen = 0
        loss_sum = 0
        for batch_idx in range(num_batches):
            window = slice(batch_idx * BATCH_SIZE, (batch_idx + 1) * BATCH_SIZE)
            batch_label = labels[window]
            # Augment batched point clouds by rotation and jittering.
            xyz = provider.jitter_point_cloud(
                provider.rotate_point_cloud(points[window, :, :]))
            # NOTE(review): the normals are concatenated un-rotated, so they no
            # longer match the rotated xyz -- confirm this is intended.
            input_data = np.concatenate((xyz, normals[window, :, :]), 2)
            # Random point dropout.
            input_data = provider.random_point_dropout(input_data)
            feed_dict = {ops['pointclouds_pl']: input_data,
                         ops['labels_pl']: batch_label,
                         ops['is_training_pl']: is_training}
            fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']]
            summary, step, _, loss_val, pred_val = sess.run(fetches, feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            total_correct += np.sum(np.argmax(pred_val, 1) == batch_label)
            total_seen += BATCH_SIZE
            loss_sum += loss_val
        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
示例14: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, train_writer):
    """Train the joint classification + segmentation model for one epoch.

    Data, labels and segmentation masks come from
    ``data_utils.get_current_data_withmask_h5``. Each full batch is rotated
    and jittered before being fed; any trailing partial batch is skipped.
    After the epoch, the following are logged:
    - mean total, classification and segmentation losses per batch;
    - classification accuracy;
    - per-point segmentation accuracy.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    points, labels, masks = data_utils.get_current_data_withmask_h5(
        TRAIN_DATA, TRAIN_LABELS, TRAIN_MASKS, NUM_POINT)
    labels = np.squeeze(labels)
    masks = np.squeeze(masks)
    num_batches = points.shape[0] // BATCH_SIZE

    total_correct = total_seen = total_correct_seg = 0
    loss_sum = classify_loss_sum = seg_loss_sum = 0
    for batch_idx in range(num_batches):
        window = slice(batch_idx * BATCH_SIZE, (batch_idx + 1) * BATCH_SIZE)
        batch_labels = labels[window]
        batch_masks = masks[window]
        # Augment batched point clouds by rotation and jittering.
        augmented = provider.jitter_point_cloud(
            provider.rotate_point_cloud(points[window, :, :]))
        feed_dict = {ops['pointclouds_pl']: augmented,
                     ops['labels_pl']: batch_labels,
                     ops['masks_pl']: batch_masks,
                     ops['is_training_pl']: is_training}
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'],
                   ops['pred'], ops['seg_pred'], ops['classify_loss'], ops['seg_loss']]
        (summary, step, _, loss_val, pred_val, seg_val,
         classify_loss, seg_loss) = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        total_correct_seg += np.sum(np.argmax(seg_val, 2) == batch_masks)
        total_correct += np.sum(np.argmax(pred_val, 1) == batch_labels)
        total_seen += BATCH_SIZE
        loss_sum += loss_val
        classify_loss_sum += classify_loss
        seg_loss_sum += seg_loss
    log_string('mean loss: %f' % (loss_sum / float(num_batches)))
    log_string('classify mean loss: %f' % (classify_loss_sum / float(num_batches)))
    log_string('seg mean loss: %f' % (seg_loss_sum / float(num_batches)))
    log_string('accuracy: %f' % (total_correct / float(total_seen)))
    log_string('seg accuracy: %f' % (total_correct_seg / (float(total_seen)*NUM_POINT)))
示例15: train_one_epoch
# 需要导入模块: import provider [as 别名]
# 或者: from provider import jitter_point_cloud [as 别名]
def train_one_epoch(sess, ops, gmm, train_writer):
    """Train the GMM-feature classifier for one epoch and log loss / accuracy.

    Data comes from ``data_utils.get_current_data_h5`` when ``TRAIN_FILE``
    contains ".h5", and from ``data_utils.get_current_data`` otherwise. Each
    full batch is augmented according to the module-level flags
    ``augment_scale``, ``augment_rotation``, ``augment_translation``,
    ``augment_jitter`` and ``augment_outlier``, applied in that order.
    ``gmm.weights_``, ``gmm.means_`` and ``np.sqrt(gmm.covariances_)`` are
    fed alongside every batch.

    Args:
        sess: TensorFlow session used to run the training ops.
        ops: dict mapping from string to tf ops / placeholders.
        gmm: fitted mixture model whose parameters are fed to the graph.
        train_writer: summary writer receiving the merged summary per step.
    """
    is_training = True
    load = (data_utils.get_current_data_h5 if ".h5" in TRAIN_FILE
            else data_utils.get_current_data)
    current_data, current_label = load(TRAIN_DATA, TRAIN_LABELS, NUM_POINT)
    current_label = np.squeeze(current_label)
    num_batches = current_data.shape[0] // BATCH_SIZE

    total_correct = 0
    total_seen = 0
    loss_sum = 0
    for batch_idx in range(num_batches):
        lo = batch_idx * BATCH_SIZE
        hi = lo + BATCH_SIZE
        batch_label = current_label[lo:hi]
        # Apply each enabled augmentation in a fixed order.
        batch = current_data[lo:hi, :, :]
        if augment_scale:
            batch = provider.scale_point_cloud(batch, smin=0.66, smax=1.5)
        if augment_rotation:
            batch = provider.rotate_point_cloud(batch)
        if augment_translation:
            batch = provider.translate_point_cloud(batch, tval=0.2)
        if augment_jitter:
            batch = provider.jitter_point_cloud(batch, sigma=0.01, clip=0.05)
        if augment_outlier:
            batch = provider.insert_outliers_to_point_cloud(batch, outlier_ratio=0.02)
        feed_dict = {
            ops['points_pl']: batch,
            ops['labels_pl']: batch_label,
            ops['w_pl']: gmm.weights_,
            ops['mu_pl']: gmm.means_,
            ops['sigma_pl']: np.sqrt(gmm.covariances_),
            ops['is_training_pl']: is_training,
        }
        fetches = [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']]
        summary, step, _, loss_val, pred_val = sess.run(fetches, feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        total_correct += np.sum(np.argmax(pred_val, 1) == batch_label)
        total_seen += BATCH_SIZE
        loss_sum += loss_val
    log_string('mean loss: %f' % (loss_sum / float(num_batches)))
    log_string('accuracy: %f' % (total_correct / float(total_seen)))