This article collects typical usage examples of the Python method tensorflow.contrib.slim.nets.resnet_v2.resnet_v2_50. If you are unsure how resnet_v2.resnet_v2_50 is used in practice, or what a concrete call looks like, the selected examples here should help. You can also explore the module that contains the method, tensorflow.contrib.slim.nets.resnet_v2, for related usage.
Two code examples of resnet_v2.resnet_v2_50 are shown below, sorted by popularity by default.
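Before the examples, here is a minimal sketch of calling resnet_v2_50 directly, assuming TensorFlow 1.x with tf.contrib available; the placeholder shape in this sketch is illustrative and not taken from the examples below:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v2

# Illustrative input: a batch of 224x224 RGB images (the shape is only for this sketch).
images = tf.placeholder(tf.float32, [None, 224, 224, 3])

# resnet_arg_scope supplies the batch-norm and regularization defaults the network expects.
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
    # With num_classes=None and global_pool=True the network returns pooled features
    # of shape [batch, 1, 1, 2048] rather than classification logits.
    net, end_points = resnet_v2.resnet_v2_50(images, num_classes=None,
                                             is_training=False, global_pool=True)

Passing num_classes=None with global_pool=True yields a pooled 2048-dimensional feature vector per image instead of classification logits, which is how Example 2 below uses the network.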
Example 1: train
# Required import: from tensorflow.contrib.slim.nets import resnet_v2 [as alias]
# Or: from tensorflow.contrib.slim.nets.resnet_v2 import resnet_v2_50 [as alias]
def train(self):
    img_size = [self.image_height, self.image_width, self.image_depth]
    train_batch = tf.train.shuffle_batch([read_tfrecord(self.train_file, img_size)],
                                         batch_size=self.train_batch_size,
                                         capacity=3000,
                                         num_threads=2,
                                         min_after_dequeue=1000)
    test_batch = tf.train.shuffle_batch([read_tfrecord(self.test_file, img_size)],
                                        batch_size=self.test_batch_size,
                                        capacity=500,
                                        num_threads=2,
                                        min_after_dequeue=300)
    init = tf.global_variables_initializer()
    # restore the pretrained ResNet-50 weights into the slim model variables
    init_fn = slim.assign_from_checkpoint_fn("resnet_v2_50.ckpt", slim.get_model_variables('resnet_v2'))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        init_fn(sess)
        train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
        test_writer = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
        for iter in range(self.max_iteration):
            inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))
            # train with a piecewise-constant learning-rate schedule
            if iter <= 500:
                self.train_step.run({self.input_data: inputs_train, self.gt: outputs_gt_train,
                                     self.learning_rate: 1e-3, self.batch_size: self.train_batch_size})
            elif iter <= self.max_iteration - 1000:
                self.train_step.run({self.input_data: inputs_train, self.gt: outputs_gt_train,
                                     self.learning_rate: 0.5e-3, self.batch_size: self.train_batch_size})
            else:
                self.train_step.run({self.input_data: inputs_train, self.gt: outputs_gt_train,
                                     self.learning_rate: 1e-4, self.batch_size: self.train_batch_size})
            # write training and test summaries
            if iter % 10 == 0:
                summary_train = sess.run(self.summary, {self.input_data: inputs_train, self.gt: outputs_gt_train,
                                                        self.batch_size: self.train_batch_size})
                train_writer.add_summary(summary_train, iter)
                train_writer.flush()
                summary_test = sess.run(self.summary, {self.input_data: inputs_test, self.gt: outputs_gt_test,
                                                       self.batch_size: self.test_batch_size})
                test_writer.add_summary(summary_test, iter)
                test_writer.flush()
            # print training loss and test loss
            if iter % 10 == 0:
                train_loss = self.cross_entropy.eval({self.input_data: inputs_train, self.gt: outputs_gt_train,
                                                      self.batch_size: self.train_batch_size})
                test_loss = self.cross_entropy.eval({self.input_data: inputs_test, self.gt: outputs_gt_test,
                                                     self.batch_size: self.test_batch_size})
                print("iter step %d training batch loss %f" % (iter, train_loss))
                print("iter step %d test loss %f\n" % (iter, test_loss))
            # save a model checkpoint
            if iter % 100 == 0:
                saver.save(sess, self.log_dir + "/model.ckpt", global_step=iter)
        coord.request_stop()
        coord.join(threads)
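Example 1 only shows the training loop; the attributes it feeds (self.input_data, self.gt, self.learning_rate, self.batch_size, self.train_step, self.cross_entropy, self.summary) and the helpers read_tfrecord / build_img_pair are defined elsewhere in that project. The following is only a rough sketch of how such a graph could be wired to resnet_v2_50; the head layer, loss, and optimizer here are illustrative assumptions, not the original model:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v2

def build_graph(height, width, depth):
    # Placeholders matching the feed keys used in the training loop above.
    input_data = tf.placeholder(tf.float32, [None, height, width, depth], name="input_data")
    gt = tf.placeholder(tf.float32, [None, 1], name="gt")                 # assumed binary label
    learning_rate = tf.placeholder(tf.float32, [], name="learning_rate")
    batch_size = tf.placeholder(tf.int32, [], name="batch_size")          # fed above; its use is not shown in the example

    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        # Pretrained backbone; these are the variables that assign_from_checkpoint_fn restores.
        features, _ = resnet_v2.resnet_v2_50(input_data, num_classes=None,
                                             is_training=True, global_pool=True)
    # New head on top of the pooled features (illustrative).
    logits = slim.fully_connected(tf.squeeze(features, [1, 2]), 1, activation_fn=None)

    # Illustrative loss and optimizer; the original project's cross_entropy may differ.
    cross_entropy = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=gt, logits=logits))
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

    tf.summary.scalar("cross_entropy", cross_entropy)
    summary = tf.summary.merge_all()
    return input_data, gt, learning_rate, batch_size, cross_entropy, train_step, summary

Note that slim.get_model_variables('resnet_v2') in Example 1 only matches variables under the resnet_v2 scope, so any extra head layers like the one sketched here would keep their random initialization and be trained from scratch.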
Example 2: build
# Required import: from tensorflow.contrib.slim.nets import resnet_v2 [as alias]
# Or: from tensorflow.contrib.slim.nets.resnet_v2 import resnet_v2_50 [as alias]
def build(self, images):
    """Builds a ResNet50 embedder for the input images.

    It assumes that the range of the pixel values in the images tensor is
    [0, 255] and should be castable to tf.uint8.

    Args:
      images: a tensor that contains the input images which has the shape of
        NxTxHxWx3 where N is the batch size, T is the maximum length of the
        sequence, H and W are the height and width of the images, and the
        last dimension is the 3 color channels.

    Returns:
      The embedding of the input images with the shape of NxTxL where L is
      the embedding size of the output.

    Raises:
      ValueError: if the shape of the input does not agree with the expected
        shape explained in the Args section.
    """
    shape = images.get_shape().as_list()
    if len(shape) != 5:
        raise ValueError(
            'The tensor shape should have 5 elements, {} is provided'.format(
                len(shape)))
    if shape[4] != 3:
        raise ValueError('Three channels are expected for the input image')

    images = tf.cast(images, tf.uint8)
    # Fold the time dimension into the batch dimension so the backbone sees
    # a standard NHWC batch of (N * T) images.
    images = tf.reshape(images,
                        [shape[0] * shape[1], shape[2], shape[3], shape[4]])
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):

        def preprocess_fn(x):
            x = tf.expand_dims(x, 0)
            x = tf.image.resize_bilinear(x, [299, 299], align_corners=False)
            return tf.squeeze(x, [0])

        images = tf.map_fn(preprocess_fn, images, dtype=tf.float32)
        net, _ = resnet_v2.resnet_v2_50(
            images, is_training=False, global_pool=True)
    # Restore the time dimension: (N * T, 1, 1, L) -> (N, T, L).
    output = tf.reshape(net, [shape[0], shape[1], -1])
    return output
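The docstring above fixes the input contract: a 5-D NxTxHxWx3 tensor with pixel values in [0, 255] and a fully static shape. A quick usage sketch, assuming the method lives on some embedder class (called Resnet50Embedder here purely for illustration):

import tensorflow as tf

# Hypothetical owner class for build(); the original class name is not shown in this example.
embedder = Resnet50Embedder()

# A batch of 4 sequences, each with 16 RGB frames of size 64x64, pixel values in [0, 255].
frames = tf.random_uniform([4, 16, 64, 64, 3], minval=0, maxval=255, dtype=tf.float32)

# build() folds N and T together, resizes frames to 299x299, runs ResNet-50 v2 with
# global pooling, and restores the time axis, giving a [4, 16, 2048] embedding tensor.
embeddings = embedder.build(frames)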