

Python slim.get_or_create_global_step Method Code Examples

This article collects typical usage examples of the Python method tensorflow.contrib.slim.get_or_create_global_step. If you are wondering what slim.get_or_create_global_step does, how to call it, or what real-world usage looks like, the curated code examples below may help. You can also browse further usage examples for the containing module, tensorflow.contrib.slim.


The following presents 5 code examples of the slim.get_or_create_global_step method, sorted by popularity by default.
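Before the examples, here is a minimal sketch of typical usage (TF 1.x; the learning-rate values and the commented-out loss are illustrative placeholders, not taken from the projects below). get_or_create_global_step returns the graph's global_step variable, creating it if it does not already exist, so it can drive learning-rate schedules and be incremented by the train op.

import tensorflow as tf
import tensorflow.contrib.slim as slim

# Return the existing global_step variable, or create it if it is absent.
global_step = slim.get_or_create_global_step()

# Typical use: feed it to a learning-rate schedule and to the optimizer,
# which increments it once per training step.
learning_rate = tf.train.exponential_decay(
    0.01, global_step, decay_steps=1000, decay_rate=0.95, staircase=True)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# train_op = optimizer.minimize(loss, global_step=global_step)  # loss built elsewhere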

Example 1: main

# Required import: from tensorflow.contrib import slim [as alias]
# Or: from tensorflow.contrib.slim import get_or_create_global_step [as alias]
def main(_):
  if not tf.gfile.Exists(FLAGS.eval_log_dir):
    tf.gfile.MakeDirs(FLAGS.eval_log_dir)

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  data = data_provider.get_data(
      dataset,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())
  endpoints = model.create_base(data.images, labels_one_hot=None)
  model.create_loss(data, endpoints)
  eval_ops = model.create_summaries(
      data, endpoints, dataset.charset, is_training=False)
  slim.get_or_create_global_step()
  session_config = tf.ConfigProto(device_count={"GPU": 0})
  slim.evaluation.evaluation_loop(
      master=FLAGS.master,
      checkpoint_dir=FLAGS.train_log_dir,
      logdir=FLAGS.eval_log_dir,
      eval_op=eval_ops,
      num_evals=FLAGS.num_batches,
      eval_interval_secs=FLAGS.eval_interval_secs,
      max_number_of_evaluations=FLAGS.number_of_steps,
      session_config=session_config) 
Developer: ringringyi, Project: DOTA_models, Lines of code: 30, Source file: eval.py

Example 2: setup_training

# Required import: from tensorflow.contrib import slim [as alias]
# Or: from tensorflow.contrib.slim import get_or_create_global_step [as alias]
def setup_training(loss_op, initial_learning_rate, steps_per_decay,
                   learning_rate_decay, momentum, max_steps,
                   sync=False, adjust_lr_sync=True,
                   num_workers=1, replica_id=0, vars_to_optimize=None, 
                   clip_gradient_norm=0, typ=None, momentum2=0.999,
                   adam_eps=1e-8):
  if sync and adjust_lr_sync:
    initial_learning_rate = initial_learning_rate * num_workers
    max_steps = np.int(max_steps / num_workers)
    steps_per_decay = np.int(steps_per_decay / num_workers)

  global_step_op = slim.get_or_create_global_step()
  lr_op          = tf.train.exponential_decay(initial_learning_rate,
    global_step_op, steps_per_decay, learning_rate_decay, staircase=True)
  if typ == 'sgd':
    optimizer      = tf.train.MomentumOptimizer(lr_op, momentum)
  elif typ == 'adam':
    optimizer      = tf.train.AdamOptimizer(learning_rate=lr_op, beta1=momentum,
                                            beta2=momentum2, epsilon=adam_eps)
  
  if sync:
    
    sync_optimizer = tf.train.SyncReplicasOptimizer(optimizer, 
                                               replicas_to_aggregate=num_workers, 
                                               replica_id=replica_id, 
                                               total_num_replicas=num_workers)
    train_op       = slim.learning.create_train_op(loss_op, sync_optimizer,
                                                   variables_to_train=vars_to_optimize,
                                                   clip_gradient_norm=clip_gradient_norm)
  else:
    sync_optimizer = None
    train_op       = slim.learning.create_train_op(loss_op, optimizer,
                                                   variables_to_train=vars_to_optimize,
                                                   clip_gradient_norm=clip_gradient_norm)
  # Compute the stop condition outside the if/else so it is defined in the sync branch as well.
  should_stop_op = tf.greater_equal(global_step_op, max_steps)
  return lr_op, global_step_op, train_op, should_stop_op, optimizer, sync_optimizer
Developer: ringringyi, Project: DOTA_models, Lines of code: 38, Source file: tf_utils.py
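For context, a hypothetical invocation of setup_training might look like the following; loss_op is assumed to be a scalar loss tensor built elsewhere, and the hyperparameter values are illustrative, not taken from the project.

# Hypothetical call; loss_op and the hyperparameter values are placeholders.
lr_op, global_step_op, train_op, should_stop_op, optimizer, sync_optimizer = setup_training(
    loss_op=loss_op,              # scalar loss tensor, assumed to exist
    initial_learning_rate=1e-3,
    steps_per_decay=10000,
    learning_rate_decay=0.1,
    momentum=0.9,
    max_steps=100000,
    typ='adam')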

Example 3: _setup_np_inference

# Required import: from tensorflow.contrib import slim [as alias]
# Or: from tensorflow.contrib.slim import get_or_create_global_step [as alias]
def _setup_np_inference(self, np_images, checkpoint_path):
    """Sets up and restores inference graph, creates and caches a Session."""
    tf.logging.info('Restoring model weights.')

    # Define inference over an image placeholder.
    _, height, width, _ = np.shape(np_images)
    image_placeholder = tf.placeholder(
        tf.float32, shape=(None, height, width, 3))

    # Preprocess batch.
    preprocessed = self.preprocess_data(image_placeholder, is_training=False)

    # Unscale and jpeg encode preprocessed images for display purposes.
    im_strings = preprocessing.unscale_jpeg_encode(preprocessed)

    # Do forward pass to get embeddings.
    embeddings = self.forward(preprocessed, is_training=False)

    # Create a saver to restore model variables.
    tf.train.get_or_create_global_step()
    saver = tf.train.Saver(tf.all_variables())

    self._image_placeholder = image_placeholder
    self._batch_encoded = embeddings

    self._np_inf_tensor_dict = {
        'embeddings': embeddings,
        'raw_image_strings': im_strings,
    }

    # Create a session and restore model variables.
    self._sess = tf.Session()
    saver.restore(self._sess, checkpoint_path) 
Developer: rky0930, Project: yolo_v2, Lines of code: 35, Source file: base_estimator.py
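Note that this example calls tf.train.get_or_create_global_step() rather than the slim wrapper. In TF 1.x both look up (and, if needed, create) the variable registered under the GLOBAL_STEP graph collection key, so the two calls below should return the same variable (a minimal sketch, assuming a fresh default graph):

import tensorflow as tf
import tensorflow.contrib.slim as slim

step_core = tf.train.get_or_create_global_step()  # core-API form used in this example
step_slim = slim.get_or_create_global_step()      # slim/contrib wrapper; returns the existing variable
assert step_core is step_slim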

Example 4: get_restorer

# Required import: from tensorflow.contrib import slim [as alias]
# Or: from tensorflow.contrib.slim import get_or_create_global_step [as alias]
def get_restorer(self):
        checkpoint_path = tf.train.latest_checkpoint(os.path.join(cfgs.TRAINED_CKPT, cfgs.VERSION))

        if checkpoint_path is not None:
            if cfgs.RESTORE_FROM_RPN:
                print('___restore from rpn___')
                model_variables = slim.get_model_variables()
                restore_variables = [var for var in model_variables if not var.name.startswith('FastRCNN_Head')] + \
                                    [slim.get_or_create_global_step()]
                for var in restore_variables:
                    print(var.name)
                restorer = tf.train.Saver(restore_variables)
            else:
                restorer = tf.train.Saver()
            print("model restore from :", checkpoint_path)
        else:
            checkpoint_path = cfgs.PRETRAINED_CKPT
            print("model restore from pretrained mode, path is :", checkpoint_path)

            model_variables = slim.get_model_variables()
            # print(model_variables)

            def name_in_ckpt_rpn(var):
                return var.op.name

            def name_in_ckpt_fastrcnn_head(var):
                '''
                Fast-RCNN/resnet_v1_50/block4 -->resnet_v1_50/block4
                :param var:
                :return:
                '''
                return '/'.join(var.op.name.split('/')[1:])

            nameInCkpt_Var_dict = {}
            for var in model_variables:
                if var.name.startswith('Fast-RCNN/'+self.base_network_name+'/block4'):
                    var_name_in_ckpt = name_in_ckpt_fastrcnn_head(var)
                    nameInCkpt_Var_dict[var_name_in_ckpt] = var
                else:
                    if var.name.startswith(self.base_network_name):
                        var_name_in_ckpt = name_in_ckpt_rpn(var)
                        nameInCkpt_Var_dict[var_name_in_ckpt] = var
                    else:
                        continue
            restore_variables = nameInCkpt_Var_dict
            for key, item in restore_variables.items():
                print("var_in_graph: ", item.name)
                print("var_in_ckpt: ", key)
                print(20*"---")
            restorer = tf.train.Saver(restore_variables)
            print(20 * "****")
            print("restore from pretrained_weighs in IMAGE_NET")
        return restorer, checkpoint_path 
Developer: DetectionTeamUCAS, Project: R2CNN_Faster-RCNN_Tensorflow, Lines of code: 55, Source file: build_whole_network.py

Example 5: get_restorer

# Required import: from tensorflow.contrib import slim [as alias]
# Or: from tensorflow.contrib.slim import get_or_create_global_step [as alias]
def get_restorer(self):
        checkpoint_path = tf.train.latest_checkpoint(os.path.join(cfgs.TRAINED_CKPT, cfgs.VERSION))

        if checkpoint_path is not None:
            if cfgs.RESTORE_FROM_RPN:
                print('___restore from rpn___')
                model_variables = slim.get_model_variables()
                restore_variables = [var for var in model_variables if not var.name.startswith('FastRCNN_Head')] + \
                                    [slim.get_or_create_global_step()]
                for var in restore_variables:
                    print(var.name)
                restorer = tf.train.Saver(restore_variables)
            else:
                restorer = tf.train.Saver()
            print("model restore from :", checkpoint_path)
        else:
            checkpoint_path = cfgs.PRETRAINED_CKPT
            print("model restore from pretrained mode, path is :", checkpoint_path)

            model_variables = slim.get_model_variables()

            # for var in model_variables:
            #     print(var.name)
            # print(20*"__++__++__")

            def name_in_ckpt_rpn(var):
                return var.op.name

            def name_in_ckpt_fastrcnn_head(var):
                '''
                Fast-RCNN/resnet_v1_50/block4 -->resnet_v1_50/block4
                Fast-RCNN/MobilenetV2/** -- > MobilenetV2 **
                :param var:
                :return:
                '''
                return '/'.join(var.op.name.split('/')[1:])

            nameInCkpt_Var_dict = {}
            for var in model_variables:
                if var.name.startswith('Fast-RCNN/'+self.base_network_name):  # +'/block4'
                    var_name_in_ckpt = name_in_ckpt_fastrcnn_head(var)
                    nameInCkpt_Var_dict[var_name_in_ckpt] = var
                else:
                    if var.name.startswith(self.base_network_name):
                        var_name_in_ckpt = name_in_ckpt_rpn(var)
                        nameInCkpt_Var_dict[var_name_in_ckpt] = var
                    else:
                        continue
            restore_variables = nameInCkpt_Var_dict
            for key, item in restore_variables.items():
                print("var_in_graph: ", item.name)
                print("var_in_ckpt: ", key)
                print(20*"___")
            restorer = tf.train.Saver(restore_variables)
            print(20 * "****")
            print("restore from pretrained_weighs in IMAGE_NET")
        return restorer, checkpoint_path 
Developer: Thinklab-SJTU, Project: R3Det_Tensorflow, Lines of code: 59, Source file: build_whole_network_refine_retinanet.py


Note: The examples of the tensorflow.contrib.slim.get_or_create_global_step method in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets are drawn from open-source projects contributed by their original authors, who retain the copyright of the source code; for distribution and use, please follow the corresponding project's License. Do not reproduce without permission.