本文整理汇总了Python中tensorflow.contrib.slim.get_variables_to_restore方法的典型用法代码示例。如果您正苦于以下问题：Python slim.get_variables_to_restore方法的具体用法是什么？Python slim.get_variables_to_restore怎么用？有哪些Python slim.get_variables_to_restore的使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.contrib.slim的用法示例。
在下文中一共展示了slim.get_variables_to_restore方法的13个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更优质的Python代码示例。
示例1: main
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def main():
    """Freeze a trained feature network into a serialized GraphDef file.

    Builds the inference graph (uint8 ``images`` placeholder -> float cast ->
    per-image ``_preprocess`` -> network -> ``features``), restores weights
    from ``args.checkpoint_in``, converts all variables reachable from the
    ``features`` node into constants, and writes the result to
    ``args.graphdef_out``.
    """
    args = parse_args()
    with tf.Session(graph=tf.Graph()) as session:
        # Batch of 128x64 RGB uint8 images; the batch dimension is left open.
        images = tf.placeholder(tf.uint8, (None, 128, 64, 3), name="images")
        preprocessed = tf.map_fn(
            _preprocess, tf.cast(images, tf.float32), back_prop=False)

        build_network = _network_factory()
        features, _ = build_network(preprocessed, reuse=None)
        # A stable node name lets consumers of the frozen graph find the output.
        features = tf.identity(features, name="features")

        restorer = tf.train.Saver(slim.get_variables_to_restore())
        restorer.restore(session, args.checkpoint_in)

        output_node_name = features.name.split(":")[0]
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session,
            tf.get_default_graph().as_graph_def(),
            [output_node_name])
        with tf.gfile.GFile(args.graphdef_out, "wb") as out_file:
            out_file.write(frozen_graph_def.SerializeToString())
示例2: build_pretrained_graph
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def build_pretrained_graph(
        self, images, resnet_layer, checkpoint, is_training, reuse=False):
    """See baseclass.

    Builds a ResNet-v2-50 over ``images`` and returns a triple:

    * the activation of endpoint ``'resnet_v2_50/block<resnet_layer>'``;
    * every variable from ``slim.get_variables_to_restore()`` whose name does
      not contain ``'global_step'``;
    * an ``init_fn(scaffold, sess)`` that restores those variables from
      ``checkpoint`` when ``is_training and not reuse``, otherwise None.
    """
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        _, endpoints = resnet_v2.resnet_v2_50(
            images, is_training=is_training, reuse=reuse)

    block_name = 'resnet_v2_50/block%d' % resnet_layer
    block_output = endpoints[block_name]

    # The step counter belongs to the current run, not to the pretrained model.
    pretrained_vars = [
        var for var in slim.get_variables_to_restore()
        if 'global_step' not in var.name
    ]

    init_fn = None
    if is_training and not reuse:
        pretrained_saver = tf.train.Saver(pretrained_vars)

        def init_fn(scaffold, sess):
            del scaffold  # Required by the scaffold init_fn signature; unused.
            pretrained_saver.restore(sess, checkpoint)

    return block_output, pretrained_vars, init_fn
示例3: restore
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def restore(self, path=None, restore_opt=True, ul_only=False):
    """Restore variables from a checkpoint into ``self.sess``.

    All candidate variable names are printed first, then the filtered list
    that is actually restored.

    Args:
      path: checkpoint path. Defaults to
        ``tf.train.latest_checkpoint(self.pr.train_dir)``.
      restore_opt: if False, skip every variable whose name contains (as a
        substring) one of 'Adam', 'beta1_power', 'beta2_power', 'Momentum',
        'cls' or 'gamma'.
      ul_only: if True, additionally skip variables whose name starts with
        'lb/' and any variable whose name contains 'global_step'.
    """
    if path is None:
        path = tf.train.latest_checkpoint(self.pr.train_dir)
    print('Restoring:', path)

    # Fetch the candidate list once; it is both logged and filtered below.
    var_list = slim.get_variables_to_restore()
    for x in var_list:
        print(x.name)
    print()

    if not restore_opt:
        # Optimizer state plus 'cls' (presumably classifier-head) variables.
        opt_names = ['Adam', 'beta1_power', 'beta2_power', 'Momentum', 'cls']
        print('removing bn gamma')
        opt_names += ['gamma']
        var_list = [x for x in var_list
                    if not any(name in x.name for name in opt_names)]

    if ul_only:
        var_list = [x for x in var_list
                    if not x.name.startswith('lb/')
                    and 'global_step' not in x.name]

    print('Restoring variables:')
    for x in var_list:
        print(x.name)
    tf.train.Saver(var_list).restore(self.sess, path)
示例4: restore
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def restore(self, path=None, restore_opt=True, restore_resnet18_blocks=True,
            restore_dilation_blocks=True):
    """Restore variables from a checkpoint into ``self.sess``.

    Args:
      path: checkpoint path. Defaults to
        ``tf.train.latest_checkpoint(self.pr.train_dir)``.
      restore_opt: if False, skip variables whose name contains 'Adam',
        'beta1_power', 'beta2_power', 'Momentum' or 'cache'.
      restore_resnet18_blocks: if False, skip variables whose name contains
        'conv2_2_', 'conv3_2_', 'conv4_2_' or 'conv5_2_'. This now applies
        independently of ``restore_opt``. Previously it was silently ignored
        unless ``restore_opt`` was also False.
      restore_dilation_blocks: accepted for interface compatibility but not
        used.
    """
    # NOTE(review): restore_dilation_blocks has no effect; the variable
    # names of the dilation blocks are not visible here - confirm intent.
    if path is None:
        path = tf.train.latest_checkpoint(self.pr.train_dir)
    print('Restoring from:', path)

    # Substrings of variable names that must NOT be restored.
    skip_names = []
    if not restore_opt:
        skip_names += ['Adam', 'beta1_power', 'beta2_power', 'Momentum',
                       'cache']
    if not restore_resnet18_blocks:
        skip_names += ['conv2_2_', 'conv3_2_', 'conv4_2_', 'conv5_2_']

    var_list = slim.get_variables_to_restore()
    if skip_names:
        var_list = [x for x in var_list
                    if not any(name in x.name for name in skip_names)]

    print('Restoring:')
    for x in var_list:
        print(x.name)
    print()
    tf.train.Saver(var_list).restore(self.sess, path)
示例5: variables_to_restore
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def variables_to_restore(scope=None, strip_scope=False):
    """Return a mapping from variable names to variables for restoring.

    Args:
      scope: optional scope. When truthy, only the variables returned by
        ``slim.get_variables_to_restore(include=[scope])`` are included;
        otherwise all variables from ``slim.get_variables_to_restore()``.
      strip_scope: if True and ``scope`` is truthy, remove the leading
        ``"<scope>/"`` from each variable's op name. Names that do not start
        with that exact prefix are kept unchanged. A trailing '/' on
        ``scope`` is tolerated. Ignored when ``scope`` is falsy.

    Returns:
      A dict mapping (optionally scope-stripped) variable op names to the
      variables themselves.
    """
    if not scope:
        return {v.op.name: v for v in slim.get_variables_to_restore()}

    # Strip exactly "<scope>/". Cutting a fixed len(scope) + 1 characters
    # corrupts names when scope already ends in '/' or when a variable was
    # matched only as a name prefix (e.g. "foobar/w" for scope "foo").
    prefix = scope.rstrip('/') + '/'
    variable_map = {}
    for var in slim.get_variables_to_restore(include=[scope]):
        var_name = var.op.name
        if strip_scope and var_name.startswith(prefix):
            var_name = var_name[len(prefix):]
        variable_map[var_name] = var
    return variable_map
示例6: get_repr_from_image
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_repr_from_image(images_reshaped, modalities, data_augment, encoder,
                        freeze_conv, wt_decay, is_training):
    """Run images through a convolutional encoder and return its features.

    Args:
      images_reshaped: image tensor. For ``['rgb']`` it is shifted by +128
        and scaled by 1/255 before being mapped to [-1, 1]. This presumes the
        raw values are centred around 0 - TODO confirm. For ``['depth']``,
        channel 0 of the last axis is rescaled as ``2 * (d - 80) / 100`` and
        channel 1 is passed through unchanged.
      modalities: either ``['rgb']`` or ``['depth']``.
      data_augment: config object. ``relight`` enables
        ``tf_utils.distort_image`` on RGB input during training;
        ``relight_fast`` is passed as its ``fast_mode``.
      encoder: name of an encoder function on ``tf_utils``. It is also used
        as the variable scope, prefixed with ``'d_'`` for depth.
      freeze_conv: if True, the encoder is built with training mode off even
        when ``is_training`` is True.
      wt_decay: unused here; kept for interface compatibility.
      is_training: whether the graph is built for training.

    Returns:
      ``(conv_feat, vars_)``: the encoder output, and the list from
      ``slim.get_variables_to_restore()`` taken right after the encoder is
      built.

    Raises:
      ValueError: if ``modalities`` is neither ``['rgb']`` nor ``['depth']``.
        Previously this case left ``x`` unbound and failed later with an
        unrelated-looking error.
    """
    if modalities == ['rgb']:
        with tf.name_scope('pre_rgb'):
            x = (images_reshaped + 128.) / 255.  # Brightness in [0, 1].
            if data_augment.relight and is_training:
                x = tf_utils.distort_image(
                    x, fast_mode=data_augment.relight_fast)
            x = (x - 0.5) * 2.0
        scope_name = encoder
    elif modalities == ['depth']:
        with tf.name_scope('pre_d'):
            d_image = images_reshaped
            depth = 2 * (d_image[..., 0] - 80.0) / 100.0
            second = d_image[..., 1]
            x = tf.concat(
                [tf.expand_dims(depth, -1), tf.expand_dims(second, -1)], 3)
        scope_name = 'd_' + encoder
    else:
        raise ValueError(
            'Unsupported modalities %r; expected [\'rgb\'] or [\'depth\'].'
            % (modalities,))

    # Frozen convolutions are always built in inference mode.
    resnet_is_training = is_training and (not freeze_conv)
    with slim.arg_scope(
            resnet_v2.resnet_utils.resnet_arg_scope(resnet_is_training)):
        fn = getattr(tf_utils, encoder)
        x, _ = fn(x, num_classes=None, global_pool=False,
                  output_stride=None, reuse=None, scope=scope_name)
    vars_ = slim.get_variables_to_restore()
    conv_feat = x
    return conv_feat, vars_
示例7: _image_to_head
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def _image_to_head(self, is_training, reuse=None):
    """Build the MobileNetV2 backbone on ``self._image`` and return its output.

    Side effects: stores ``slim.get_variables_to_restore()`` (taken right
    after the backbone is built) in ``self.variables_to_restore``, appends
    the output to ``self._act_summaries``, and registers it as
    ``self._layers['head']``. ``reuse`` is accepted but not used.
    """
    backbone_scope = mobilenet_v2.training_scope(is_training=is_training)
    with slim.arg_scope(backbone_scope):
        head, _ = mobilenet_v2.mobilenet_base(
            self._image, conv_defs=CTPN_DEF)
    self.variables_to_restore = slim.get_variables_to_restore()
    self._act_summaries.append(head)
    self._layers['head'] = head
    return head
示例8: get_variables_to_restore
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_variables_to_restore(self, variables, var_keep_dic):
    """Placeholder hook for choosing which variables to restore.

    Both ``variables`` and ``var_keep_dic`` are ignored and None is returned.
    NOTE(review): presumably intended to be overridden - confirm.
    """
    return None
示例9: get_model_init_fn
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
    """Gets the function initializing model variables from a checkpoint.

    Args:
      train_logdir: Log directory for training. If it already contains a
        checkpoint, no initialization function is returned.
      tf_initial_checkpoint: TensorFlow checkpoint for initialization, or
        None to skip initialization.
      initialize_last_layer: Initialize last layer or not.
      last_layers: Last layers of the model. These scopes are excluded from
        restoring when ``initialize_last_layer`` is False.
      ignore_missing_vars: Ignore missing variables in the checkpoint.

    Returns:
      The function returned by ``slim.assign_from_checkpoint_fn``, or None
      when there is nothing to initialize.
    """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None
    if tf.train.latest_checkpoint(train_logdir):
        # Resuming an existing run takes precedence over the initial weights.
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
    # The step counter is never taken from the initial checkpoint.
    excluded = ['global_step'] + (
        [] if initialize_last_layer else list(last_layers))
    restorable = slim.get_variables_to_restore(exclude=excluded)
    if not restorable:
        return None
    return slim.assign_from_checkpoint_fn(
        tf_initial_checkpoint,
        restorable,
        ignore_missing_vars=ignore_missing_vars)
示例10: build_inceptionv3_graph
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def build_inceptionv3_graph(images, endpoint, is_training, checkpoint,
                            reuse=False):
    """Builds an InceptionV3 model graph.

    Args:
      images: A 4-D float32 `Tensor` of batch images.
      endpoint: String, name of the InceptionV3 endpoint.
      is_training: Boolean, whether or not to build a training or inference
        graph.
      checkpoint: String, path to the pretrained model checkpoint.
      reuse: Boolean, whether or not we are reusing the embedder. Forwarded to
        ``inception.inception_v3`` so that a second build reuses the existing
        InceptionV3 variables, matching the ResNet builder in this file.
        False leaves the enclosing variable scope's reuse setting in effect.

    Returns:
      inception_output: `Tensor` holding the InceptionV3 output.
      inception_variables: List of inception variables (everything from
        ``slim.get_variables_to_restore()`` except names containing
        'global_step').
      init_fn: Function ``(scaffold, sess)`` restoring those variables from
        ``checkpoint``. None when reusing or when not training.
    """
    with slim.arg_scope(inception.inception_v3_arg_scope()):
        _, endpoints = inception.inception_v3(
            images, num_classes=1001, is_training=is_training, reuse=reuse)
    inception_output = endpoints[endpoint]
    # The step counter belongs to the current run, not the pretrained model.
    inception_variables = [
        i for i in slim.get_variables_to_restore()
        if 'global_step' not in i.name]
    if is_training and not reuse:
        init_saver = tf.train.Saver(inception_variables)

        def init_fn(scaffold, sess):
            del scaffold  # Required by the scaffold init_fn signature; unused.
            init_saver.restore(sess, checkpoint)
    else:
        init_fn = None
    return inception_output, inception_variables, init_fn
示例11: load_ckpt
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def load_ckpt(sess, model_dir, variables_to_restore=None):
    """Restore variables from the latest checkpoint recorded in ``model_dir``.

    Args:
      sess: session in which the assignment op is run.
      model_dir: directory whose checkpoint state names the checkpoint.
      variables_to_restore: list or name->variable dict accepted by
        ``slim.assign_from_checkpoint``. Defaults to
        ``slim.get_variables_to_restore()``.

    Raises:
      FileNotFoundError: if ``model_dir`` has no checkpoint state.
        Previously this surfaced as an AttributeError on None.
    """
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError(f'No checkpoint found in {model_dir}')
    model_path = ckpt.model_checkpoint_path
    if variables_to_restore is None:
        variables_to_restore = slim.get_variables_to_restore()
    restore_op, restore_fd = slim.assign_from_checkpoint(
        model_path, variables_to_restore)
    sess.run(restore_op, feed_dict=restore_fd)
    print(f'{model_path} loaded')
示例12: __init__
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def __init__(self, sess, ontar_model_dir, is_reg=False, seq_feature_only=False):
    """Build the on-target model and load its weights from ``ontar_model_dir``.

    Args:
      sess: session used for variable initialization and restoring.
      ontar_model_dir: directory holding the on-target checkpoint.
      is_reg: build ``build_ontar_reg_model`` instead of ``build_ontar_model``.
      seq_feature_only: input placeholder has 4 channels instead of 8 in its
        last dimension (shape ``[None, 1, 23, C]``).
    """
    self.sess = sess
    channels = 4 if seq_feature_only else 8
    self.inputs_sg = tf.placeholder(
        dtype=tf.float32, shape=[None, 1, 23, channels])

    builder = build_ontar_reg_model if is_reg else build_ontar_model
    self.pred_ontar = builder(self.inputs_sg)

    # Checkpoint names are the op names with the first len('ontar/') == 6
    # characters removed.
    # NOTE(review): the filter only checks startswith('ontar'), so a scope
    # like 'ontar_x/...' would be cut at the wrong place - confirm scopes.
    prefix_len = len('ontar/')
    on_vars = {}
    for var in slim.get_variables_to_restore():
        if var.name.startswith('ontar'):
            on_vars[var.op.name[prefix_len:]] = var

    sess.run(create_init_op())
    load_ckpt(sess, ontar_model_dir, variables_to_restore=on_vars)
示例13: get_variables_to_restore
# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_variables_to_restore(include_vars=(), exclude_global_pool=False):
    """Collect model variables plus variables from the ``include_vars`` scopes.

    Args:
      include_vars: scopes forwarded unchanged as
        ``slim.get_variables_to_restore(include=include_vars)``. The default
        empty tuple selects no extra variables, the same as the previous
        ``[]`` default, without sharing a mutable default across calls.
      exclude_global_pool: if True, skip any variable whose op name contains
        'global_pool'.

    Returns:
      A list of variables: model variables first, then the included ones,
      each appearing at most once. A model variable that also lies in an
      included scope used to be listed twice, and ``tf.train.Saver`` rejects
      a var_list with duplicate names.
    """
    candidates = list(slim.get_model_variables())
    candidates += slim.get_variables_to_restore(include=include_vars)

    variables_to_restore = []
    seen_names = set()
    for var in candidates:
        name = var.op.name
        if exclude_global_pool and 'global_pool' in name:
            continue
        if name in seen_names:
            continue
        seen_names.add(name)
        variables_to_restore.append(var)
    return variables_to_restore