当前位置: 首页>>代码示例>>Python>>正文


Python slim.get_variables_to_restore方法代码示例

本文整理汇总了Python中tensorflow.contrib.slim.get_variables_to_restore方法的典型用法代码示例。如果您正苦于以下问题：Python slim.get_variables_to_restore方法的具体用法？Python slim.get_variables_to_restore怎么用？Python slim.get_variables_to_restore使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.contrib.slim的更多用法示例。


在下文中一共展示了slim.get_variables_to_restore方法的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def main():
    """Freeze the feature network into a standalone GraphDef file.

    Builds an inference graph (uint8 image placeholder -> per-image
    preprocessing -> network -> identity op named "features"), restores the
    weights from ``args.checkpoint_in`` and writes the graph, with variables
    converted to constants, to ``args.graphdef_out``.
    """
    args = parse_args()

    graph = tf.Graph()
    with tf.Session(graph=graph) as session:
        # Batch of 128x64, 3-channel uint8 images.
        input_var = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        float_images = tf.cast(input_var, tf.float32)
        image_var = tf.map_fn(
            lambda img: _preprocess(img), float_images, back_prop=False)

        network = _network_factory()
        features, _ = network(image_var, reuse=None)
        features = tf.identity(features, name="features")

        restorer = tf.train.Saver(slim.get_variables_to_restore())
        restorer.restore(session, args.checkpoint_in)

        # Op name of the output (tensor name without the ":<index>" suffix).
        output_node = features.name.split(":")[0]
        frozen = tf.graph_util.convert_variables_to_constants(
            session, tf.get_default_graph().as_graph_def(), [output_node])

        serialized = frozen.SerializeToString()
        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(serialized)
开发者ID:nwojke,项目名称:deep_sort,代码行数:24,代码来源:freeze_model.py

示例2: build_pretrained_graph

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def build_pretrained_graph(self, images, resnet_layer, checkpoint,
                           is_training, reuse=False):
    """See baseclass.

    Builds a ResNet-v2-50 graph on ``images`` and returns the output of the
    requested block, the variables to restore, and an optional init function.

    Args:
      images: Input batch passed to ``resnet_v2.resnet_v2_50``.
      resnet_layer: Integer block index; the endpoint read is
        ``'resnet_v2_50/block<resnet_layer>'``.
      checkpoint: Checkpoint path restored by the returned ``init_fn``.
      is_training: Forwarded to ``resnet_v2_50``.
      reuse: Forwarded to ``resnet_v2_50``; when truthy no ``init_fn`` is built.

    Returns:
      Tuple ``(resnet_output, resnet_variables, init_fn)``; ``init_fn`` is
      None unless ``is_training`` is truthy and ``reuse`` is falsy.
    """
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        _, endpoints = resnet_v2.resnet_v2_50(
            images, is_training=is_training, reuse=reuse)
        endpoint_name = 'resnet_v2_50/block%d' % resnet_layer
        resnet_output = endpoints[endpoint_name]
        # Unfiltered slim restore list, minus any global_step variable.
        resnet_variables = [
            var for var in slim.get_variables_to_restore()
            if 'global_step' not in var.name
        ]
        if not is_training or reuse:
            return resnet_output, resnet_variables, None

        init_saver = tf.train.Saver(resnet_variables)

        def init_fn(scaffold, sess):
            del scaffold  # Unused; kept to match the (scaffold, sess) callback shape.
            init_saver.restore(sess, checkpoint)

        return resnet_output, resnet_variables, init_fn
开发者ID:rky0930,项目名称:yolo_v2,代码行数:22,代码来源:model.py

示例3: restore

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def restore(self, path = None, restore_opt = True, ul_only = False):
    if path is None:
      path = tf.train.latest_checkpoint(self.pr.train_dir)      
    print 'Restoring:', path
    var_list = slim.get_variables_to_restore()
    for x in var_list:
      print x.name
    print
    var_list = slim.get_variables_to_restore()
    if not restore_opt:
      opt_names = ['Adam', 'beta1_power', 'beta2_power', 'Momentum'] + ['cls']# + ['renorm_mean_weight', 'renorm_stddev_weight', 'moving_mean', 'renorm']
      print 'removing bn gamma'
      opt_names += ['gamma']
      var_list = [x for x in var_list if not any(name in x.name for name in opt_names)]
    if ul_only:
      var_list = [x for x in var_list if not x.name.startswith('lb/') and ('global_step' not in x.name)]
    #var_list = [x for x in var_list if ('global_step' not in x.name)]
    print 'Restoring variables:'
    for x in var_list:
      print x.name
    tf.train.Saver(var_list).restore(self.sess, path)
    # print 'TEST: restoring all'
    # tf.train.Saver().restore(self.sess, path) 
开发者ID:andrewowens,项目名称:multisensory,代码行数:25,代码来源:videocls.py

示例4: restore

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def restore(self, path = None, restore_opt = True, restore_resnet18_blocks = True, restore_dilation_blocks = True):
    if path is None:
      path = tf.train.latest_checkpoint(self.pr.train_dir)      
    print 'Restoring from:', path
    var_list = slim.get_variables_to_restore()
    opt_names = ['Adam', 'beta1_power', 'beta2_power', 'Momentum', 'cache']
    if not restore_resnet18_blocks:
      opt_names += ['conv2_2_', 'conv3_2_', 'conv4_2_', 'conv5_2_']

    if not restore_opt:
      var_list = [x for x in var_list if not any(name in x.name for name in opt_names)]

    print 'Restoring:'
    for x in var_list:
      print x.name
    print
    tf.train.Saver(var_list).restore(self.sess, path) 
开发者ID:andrewowens,项目名称:multisensory,代码行数:19,代码来源:shift_net.py

示例5: variables_to_restore

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def variables_to_restore(scope=None, strip_scope=False):
  """Returns a dict of variables to restore, optionally limited to a scope.

  Args:
    scope: Optional variable-scope name, with or without a trailing '/'.
      When given, variables come from
      ``slim.get_variables_to_restore(include=[scope])``; when empty or
      None, all variables from ``slim.get_variables_to_restore()`` are used.
    strip_scope: If True and ``scope`` is given, the returned keys omit the
      leading ``'<scope>/'``. Ignored when ``scope`` is empty or None.

  Returns:
    A dictionary mapping variable op names (optionally with the scope prefix
    stripped) to variables.
  """
  if not scope:
    return {v.op.name: v for v in slim.get_variables_to_restore()}

  # Length of '<scope>/' regardless of whether the caller included the '/'.
  prefix_len = len(scope.rstrip('/')) + 1
  variable_map = {}
  for var in slim.get_variables_to_restore(include=[scope]):
    var_name = var.op.name[prefix_len:] if strip_scope else var.op.name
    variable_map[var_name] = var
  return variable_map
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:30,代码来源:utils.py

示例6: get_repr_from_image

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_repr_from_image(images_reshaped, modalities, data_augment, encoder,
                        freeze_conv, wt_decay, is_training):
  """Pre-process images and run them through an encoder from tf_utils.

  Args:
    images_reshaped: Image tensor. For ['rgb'] it is mapped with
      (x + 128) / 255 and then to (x - 0.5) * 2; for ['depth'] channel 0 is
      normalized as 2 * (d - 80) / 100, channel 1 is passed through, and the
      two are concatenated on axis 3.
    modalities: Either ['rgb'] or ['depth'].
    data_augment: Object with ``relight`` / ``relight_fast`` attributes; when
      ``relight`` and ``is_training`` are truthy, RGB input goes through
      ``tf_utils.distort_image``.
    encoder: Name of the encoder function in ``tf_utils``; also used as its
      scope name (prefixed with 'd_' for depth).
    freeze_conv: If True, ``is_training and not freeze_conv`` is False, and
      that value is passed to ``resnet_utils.resnet_arg_scope``.
      NOTE(review): presumably that argument is the is_training flag in the
      slim version used here — confirm.
    wt_decay: Not used in this function.
    is_training: Whether the graph is built for training.

  Returns:
    Tuple ``(conv_feat, vars_)``: the encoder output tensor and the
    unfiltered list from ``slim.get_variables_to_restore()``.

  Raises:
    ValueError: If ``modalities`` is neither ['rgb'] nor ['depth'].
  """
  # Pass image through lots of convolutional layers, to obtain pool5
  if modalities == ['rgb']:
    with tf.name_scope('pre_rgb'):
      x = (images_reshaped + 128.) / 255.  # Convert to brightness between 0 and 1.
      if data_augment.relight and is_training:
        x = tf_utils.distort_image(x, fast_mode=data_augment.relight_fast)
      x = (x - 0.5) * 2.0
    scope_name = encoder
  elif modalities == ['depth']:
    with tf.name_scope('pre_d'):
      d_image = images_reshaped
      x = 2 * (d_image[..., 0] - 80.0) / 100.0
      y = d_image[..., 1]
      d_image = tf.concat([tf.expand_dims(x, -1), tf.expand_dims(y, -1)], 3)
      x = d_image
    scope_name = 'd_' + encoder
  else:
    # Otherwise x / scope_name would be unbound below (UnboundLocalError).
    raise ValueError(
        "Unsupported modalities %r; expected ['rgb'] or ['depth']." % (modalities,))

  resnet_is_training = is_training and (not freeze_conv)
  with slim.arg_scope(resnet_v2.resnet_utils.resnet_arg_scope(resnet_is_training)):
    fn = getattr(tf_utils, encoder)
    x, end_points = fn(x, num_classes=None, global_pool=False,
                       output_stride=None, reuse=None,
                       scope=scope_name)
  vars_ = slim.get_variables_to_restore()

  conv_feat = x
  return conv_feat, vars_
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:31,代码来源:nav_utils.py

示例7: _image_to_head

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def _image_to_head(self, is_training, reuse=None):
        """Build the MobileNet-v2 backbone on self._image and register it as the head.

        Args:
          is_training: Passed to ``mobilenet_v2.training_scope``.
          reuse: Accepted for interface compatibility; not used here.

        Returns:
          The backbone output tensor, which is also stored in
          ``self._layers['head']`` and appended to ``self._act_summaries``.
        """
        scope = mobilenet_v2.training_scope(is_training=is_training)
        with slim.arg_scope(scope):
            head, _ = mobilenet_v2.mobilenet_base(self._image, conv_defs=CTPN_DEF)

        # Unfiltered result of slim.get_variables_to_restore() at this point.
        self.variables_to_restore = slim.get_variables_to_restore()

        self._act_summaries.append(head)
        self._layers['head'] = head
        return head
开发者ID:Sanster,项目名称:tf_ctpn,代码行数:12,代码来源:mobilenet_v2.py

示例8: get_variables_to_restore

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_variables_to_restore(self, variables, var_keep_dic):
        """Do nothing: both arguments are ignored and None is returned."""
        return None
开发者ID:Sanster,项目名称:tf_ctpn,代码行数:4,代码来源:mobilenet_v2.py

示例9: get_model_init_fn

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
    """Build the function that initializes model variables from a checkpoint.

    No initializer is produced when there is no initial checkpoint, when
    ``train_logdir`` already contains a checkpoint, or when nothing is left
    to restore after exclusions.

    Args:
      train_logdir: Log directory for training.
      tf_initial_checkpoint: TensorFlow checkpoint for initialization.
      initialize_last_layer: Whether to initialize the last layers too.
      last_layers: Scope names of the last layers, excluded when
        ``initialize_last_layer`` is False.
      ignore_missing_vars: Ignore missing variables in the checkpoint.

    Returns:
      The callable from ``slim.assign_from_checkpoint_fn``, or None.
    """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None

    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

    # global_step is never restored; last layers only when requested.
    excluded = ['global_step']
    if not initialize_last_layer:
        excluded += last_layers

    restorable = slim.get_variables_to_restore(exclude=excluded)
    if not restorable:
        return None
    return slim.assign_from_checkpoint_fn(
        tf_initial_checkpoint,
        restorable,
        ignore_missing_vars=ignore_missing_vars)
开发者ID:sercant,项目名称:mobile-segmentation,代码行数:42,代码来源:train_utils.py

示例10: build_inceptionv3_graph

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def build_inceptionv3_graph(images, endpoint, is_training, checkpoint,
                            reuse=False):
  """Builds an InceptionV3 model graph.

  Args:
    images: A 4-D float32 `Tensor` of batch images.
    endpoint: String, name of the InceptionV3 endpoint.
    is_training: Boolean, whether or not to build a training or inference graph.
    checkpoint: String, path to the pretrained model checkpoint.
    reuse: Boolean, whether or not we are reusing the embedder. Forwarded to
      `inception.inception_v3` (as the ResNet builder forwards it to
      `resnet_v2_50`), so that when True the existing InceptionV3 variables
      are reused instead of being created again.
  Returns:
    inception_output: `Tensor` holding the InceptionV3 output.
    inception_variables: List of inception variables (global_step excluded).
    init_fn: Function to initialize the weights from `checkpoint`; None when
      reusing or when not training.
  """
  with slim.arg_scope(inception.inception_v3_arg_scope()):
    _, endpoints = inception.inception_v3(
        images, num_classes=1001, is_training=is_training, reuse=reuse)
    inception_output = endpoints[endpoint]
    inception_variables = slim.get_variables_to_restore()
    inception_variables = [
        i for i in inception_variables if 'global_step' not in i.name]
    if is_training and not reuse:
      init_saver = tf.train.Saver(inception_variables)
      def init_fn(scaffold, sess):
        del scaffold  # Unused; kept to match the (scaffold, sess) callback shape.
        init_saver.restore(sess, checkpoint)
    else:
      init_fn = None
    return inception_output, inception_variables, init_fn
开发者ID:rky0930,项目名称:yolo_v2,代码行数:32,代码来源:model.py

示例11: load_ckpt

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def load_ckpt(sess, model_dir, variables_to_restore=None):
    """Restore variables in ``sess`` from the latest checkpoint in ``model_dir``.

    Args:
        sess: Session in which the restore op is run.
        model_dir: Directory containing the checkpoint state file.
        variables_to_restore: Variables to load (a list, or a dict mapping
            checkpoint names to variables); defaults to
            ``slim.get_variables_to_restore()``.

    Raises:
        FileNotFoundError: If no checkpoint is recorded in ``model_dir``.
    """
    ckpt = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when model_dir has no checkpoint;
    # fail with a clear message instead of an AttributeError on None.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError(f'No checkpoint found in {model_dir}')
    model_path = ckpt.model_checkpoint_path
    if variables_to_restore is None:
        variables_to_restore = slim.get_variables_to_restore()
    restore_op, restore_fd = slim.assign_from_checkpoint(
        model_path, variables_to_restore)
    sess.run(restore_op, feed_dict=restore_fd)
    print(f'{model_path} loaded')
开发者ID:bm2-lab,项目名称:DeepCRISPR,代码行数:11,代码来源:deepcrispr_src.py

示例12: __init__

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def __init__(self, sess, ontar_model_dir, is_reg=False, seq_feature_only=False):
        """Build the on-target model graph and load its weights into ``sess``.

        Args:
            sess: Session used for variable initialization and checkpoint loading.
            ontar_model_dir: Directory holding the on-target model checkpoint.
            is_reg: If True, build with build_ontar_reg_model instead of
                build_ontar_model.
            seq_feature_only: If True the input placeholder has 4 channels
                in its last dimension, otherwise 8.
        """
        self.sess = sess
        n_channels = 4 if seq_feature_only else 8
        self.inputs_sg = tf.placeholder(dtype=tf.float32, shape=[None, 1, 23, n_channels])
        builder = build_ontar_reg_model if is_reg else build_ontar_model
        self.pred_ontar = builder(self.inputs_sg)

        # Variables whose name starts with 'ontar', keyed by op name with the
        # first len('ontar/') == 6 characters removed.
        prefix_len = len('ontar/')
        on_vars = {
            var.op.name[prefix_len:]: var
            for var in slim.get_variables_to_restore()
            if var.name.startswith('ontar')
        }
        sess.run(create_init_op())
        load_ckpt(sess, ontar_model_dir, variables_to_restore=on_vars)
开发者ID:bm2-lab,项目名称:DeepCRISPR,代码行数:16,代码来源:deepcrispr_src.py

示例13: get_variables_to_restore

# 需要导入模块: from tensorflow.contrib import slim [as 别名]
# 或者: from tensorflow.contrib.slim import get_variables_to_restore [as 别名]
def get_variables_to_restore(include_vars=None, exclude_global_pool=False):
    """Collect model variables plus variables under extra scopes, without duplicates.

    Args:
        include_vars: Optional list of scope names forwarded to
            ``slim.get_variables_to_restore(include=...)``. None or an empty
            list adds no extra variables (model variables only).
        exclude_global_pool: If True, skip variables whose op name contains
            'global_pool'.

    Returns:
        A list of variables in first-seen order, each op name at most once.
    """
    # None replaces the shared mutable [] default. slim interprets
    # include=None as "all variables", so map it to [] to keep the old meaning.
    if include_vars is None:
        include_vars = []
    candidates = list(slim.get_model_variables())
    candidates += slim.get_variables_to_restore(include=include_vars)

    variables_to_restore = []
    seen_names = set()
    for var in candidates:
        if exclude_global_pool and 'global_pool' in var.op.name:
            continue
        # A model variable can also fall under an include scope; listing it
        # twice makes tf.train.Saver reject the var_list, so keep the first.
        if var.op.name in seen_names:
            continue
        seen_names.add(var.op.name)
        variables_to_restore.append(var)
    return variables_to_restore
开发者ID:vicwer,项目名称:sense_classification,代码行数:15,代码来源:multi_gpus_train.py


注:本文中的tensorflow.contrib.slim.get_variables_to_restore方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。