当前位置: 首页>>代码示例>>Python>>正文


Python monitored_session.Scaffold方法代码示例

本文整理汇总了Python中tensorflow.python.training.monitored_session.Scaffold方法的典型用法代码示例。如果您正苦于以下问题：Python monitored_session.Scaffold方法的具体用法？Python monitored_session.Scaffold怎么用？Python monitored_session.Scaffold使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.training.monitored_session的用法示例。


在下文中一共展示了monitored_session.Scaffold方法的8个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _scaffold_with_init

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def _scaffold_with_init(scaffold, saver, checkpoint_path):
  """Creates a scaffold that loads the given checkpoint using an init_fn.

  Args:
    scaffold: The scaffold to copy.
    saver: The saver to use when restoring the checkpoint.
    checkpoint_path: An absolute path to a checkpoint.

  Returns:
    A scaffold with an init_fn that loads the given checkpoint. If the scaffold
    provided already has an init_fn, the scaffold is returned unchanged.
  """
  if scaffold.init_fn:
    # Never override an init_fn the caller configured explicitly.
    return scaffold

  def restore_checkpoint(_, session):
    saver.restore(session, checkpoint_path)

  # Carry over every user-settable field of the given scaffold, including
  # ready_for_local_init_op; leaving it out would replace a custom op with the
  # default one when the new scaffold is finalized.
  return monitored_session.Scaffold(
      init_op=scaffold.init_op,
      init_feed_dict=scaffold.init_feed_dict,
      init_fn=restore_checkpoint,
      ready_op=scaffold.ready_op,
      ready_for_local_init_op=scaffold.ready_for_local_init_op,
      local_init_op=scaffold.local_init_op,
      summary_op=scaffold.summary_op,
      saver=scaffold.saver)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:28,代码来源:evaluation.py

示例2: setUp

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def setUp(self):
    """Prepare a temporary model dir plus a graph holding a default scaffold,
    the global step and a train op that increments that step by one."""
    self.model_dir = tempfile.mkdtemp()
    graph = tf.Graph()
    self.graph = graph
    with graph.as_default():
      self.scaffold = monitored_session.Scaffold()
      step = tf.contrib.framework.get_or_create_global_step()
      self.global_step = step
      # Each run of train_op bumps the global step by exactly one.
      self.train_op = tf.assign_add(step, 1)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:9,代码来源:basic_session_run_hooks_test.py

示例3: test_raise_when_scaffold_and_summary_op_both_present

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def test_raise_when_scaffold_and_summary_op_both_present(self):
    # SummarySaverHook must raise ValueError when it receives both a scaffold
    # and an explicit summary_op.
    with self.assertRaises(ValueError):
      both_given = {
          'scaffold': tf.train.Scaffold(),
          'summary_op': self.summary_op,
      }
      tf.train.SummarySaverHook(**both_given)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:6,代码来源:basic_session_run_hooks_test.py

示例4: correlation_matrix

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def correlation_matrix(nb_batches, checkpoint_dir):
    """Computes logits and labels of validation posts and saves them as numpy files.

    Parameters:
        nb_batches: Number of validation batches to evaluate; must be positive.
        checkpoint_dir: Directory containing the checkpoints saved during
            training; its latest checkpoint is restored.

    Returns:
        (posts_logits, posts_labels): the per-batch logits stacked with
        np.vstack and the per-batch labels concatenated with np.hstack. The
        same arrays are written to data/posts_logits.npy and
        data/posts_labels.npy.

    Raises:
        ValueError: if nb_batches is not positive or checkpoint_dir contains
            no checkpoint.
    """
    if nb_batches <= 0:
        # np.vstack below would otherwise fail on an empty list with an
        # unhelpful message.
        raise ValueError('nb_batches must be positive, got %r' % (nb_batches,))

    # latest_checkpoint returns None when no checkpoint exists. The session
    # creator would then initialize the variables instead of restoring them,
    # and the logits would silently come from an untrained model.
    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoint found in %r' % (checkpoint_dir,))

    with tf.Graph().as_default():
        config = _CONFIG.copy()
        config['mode'] = 'validation'
        model = DeepSentiment(config)

        # Load model
        scaffold = monitored_session.Scaffold(
            init_op=None, init_feed_dict=None,
            init_fn=None, saver=None)
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_filename_with_path=checkpoint_path,
            master='',
            config=None)

        posts_logits = []
        posts_labels = []
        # Entering the MonitoredSession generates the input queue.
        with monitored_session.MonitoredSession(
            session_creator=session_creator, hooks=None) as session:
            for _ in range(nb_batches):
                np_logits, np_labels = session.run([model.logits, model.labels])
                posts_logits.append(np_logits)
                posts_labels.append(np_labels)

    posts_logits, posts_labels = np.vstack(posts_logits), np.hstack(posts_labels)
    np.save('data/posts_logits.npy', posts_logits)
    np.save('data/posts_labels.npy', posts_labels)
    return posts_logits, posts_labels
开发者ID:anthonyhu,项目名称:tumblr-emotions,代码行数:37,代码来源:im_text_rnn_model.py

示例5: __new__

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def __new__(cls, model, step=None, train_op=None, **kwargs):
    """Creates an EstimatorSpec whose train_op also advances the pruning step.

    Args:
      cls: The class being instantiated.
      model: Model whose `layers` are scanned for `PruneLowMagnitude` wrappers
        and whose `updates` are grouped into the train op.
      step: Optional step value. When None, every prunable layer's
        `pruning_step` is incremented by 1 per run; otherwise it is assigned
        `step` cast to int32.
      train_op: The caller's training op; required.
      **kwargs: Forwarded to the parent `__new__`. Must contain `mode`. The
        `train_op` and `scaffold` entries are set here; a provided `scaffold`
        (if not None) is copied with an added `init_fn`.

    Returns:
      The instance created by the parent class's `__new__`.

    Raises:
      ValueError: If `mode` is missing, `train_op` is None, or the provided
        scaffold already sets an `init_fn`.
    """
    if "mode" not in kwargs:
      raise ValueError("Must provide a mode (TRAIN/EVAL/PREDICT) when "
                       "creating an EstimatorSpec")

    if train_op is None:
      raise ValueError(
          "Must provide train_op for creating a PruningEstimatorSpec")

    def _get_step_increment_ops(model, step=None):
      """Returns ops to increment the pruning_step in the prunable layers."""
      increment_ops = []

      for layer in model.layers:
        if isinstance(layer, PruneLowMagnitude):
          if step is None:
            # Add ops to increment the pruning_step by 1
            increment_ops.append(state_ops.assign_add(layer.pruning_step, 1))
          else:
            increment_ops.append(
                state_ops.assign(layer.pruning_step,
                                 math_ops.cast(step, dtypes.int32)))

      return control_flow_ops.group(increment_ops)

    # Run the pruning-step updates and the model's own update ops together
    # with the caller's train_op.
    step_increment_ops = _get_step_increment_ops(model, step)
    pruning_ops = [step_increment_ops, model.updates]
    kwargs["train_op"] = control_flow_ops.group(pruning_ops, train_op)

    def init_fn(scaffold, session):  # pylint: disable=unused-argument
      return session.run(step_increment_ops)

    # Treat an explicit scaffold=None like an absent scaffold instead of
    # crashing on `None.init_fn`.
    old_scaffold = kwargs.get("scaffold")
    if old_scaffold is None:
      scaffold = monitored_session.Scaffold(init_fn=init_fn)
    elif old_scaffold.init_fn is None:
      scaffold = monitored_session.Scaffold(
          init_fn=init_fn, copy_from_scaffold=old_scaffold)
    else:
      # TODO(suyoggupta): Figure out a way to merge the init_fn of the
      # original scaffold with the one defined above.
      raise ValueError("Scaffold provided to PruningEstimatorSpec must not "
                       "set an init_fn.")

    kwargs["scaffold"] = scaffold

    return super(PruningEstimatorSpec, cls).__new__(cls, **kwargs)
开发者ID:tensorflow,项目名称:model-optimization,代码行数:55,代码来源:estimator_utils.py

示例6: _train_model

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def _train_model(self, input_fn, hooks):
    """Builds a fresh training graph and runs the train op until told to stop.

    Args:
      input_fn: Callable returning a `(features, labels)` pair.
      hooks: Iterable of additional session-run hooks for every worker; `None`
        is treated as no additional hooks.

    Returns:
      The loss from the last `run` call, or `None` if the monitored session
      stopped before any step ran.
    """
    all_hooks = []
    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, labels = input_fn()
      self._check_inputs(features, labels)
      model_fn_ops = self._call_legacy_get_train_ops(features, labels)
      ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
      all_hooks.extend([
          basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
          basic_session_run_hooks.LoggingTensorHook(
              {
                  'loss': model_fn_ops.loss,
                  'step': global_step
              },
              every_n_iter=100)
      ])
      # Accept hooks=None instead of failing inside list.extend.
      all_hooks.extend(hooks or [])

      scaffold = model_fn_ops.training_scaffold or monitored_session.Scaffold()
      # Register a deferred-build saver only if neither the scaffold nor the
      # graph's SAVERS collection provides one.
      if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            saver.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                defer_build=True))

      chief_hooks = []
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        # Add a CheckpointSaverHook only when none of the supplied hooks is
        # one already, so there is never a second saver hook.
        saver_hook_exists = any(
            isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
            for h in (all_hooks + model_fn_ops.training_hooks +
                      model_fn_ops.training_chief_hooks))
        if not saver_hook_exists:
          chief_hooks = [
              basic_session_run_hooks.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=scaffold)
          ]
      loss = None
      with monitored_session.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=scaffold,
          hooks=all_hooks + model_fn_ops.training_hooks,
          chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._config.tf_config) as mon_sess:
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
      summary_io.SummaryWriterCache.clear()
      return loss
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:63,代码来源:estimator.py

示例7: outliers_detection

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def outliers_detection(checkpoint_dir):
    """Finds outliers using Euclidean distance in the last dense layer.

    A first pass over the validation set computes the mean of
    ``model.concat_features``. A second pass records, for every position k
    inside a batch, the post at position k whose features lie farthest from
    that mean.

    Parameters:
        checkpoint_dir: Directory containing the checkpoints saved during
            training; its latest checkpoint is restored.

    Returns:
        (max_norms, max_post_ids, max_logits), each with batch_size rows. They
        are also saved to data/max_norms.npy, data/max_post_ids.npy and
        data/max_logits.npy.

    Raises:
        ValueError: if checkpoint_dir contains no checkpoint, or the dataset
            holds fewer samples than one batch.
    """
    # latest_checkpoint returns None when no checkpoint exists; without this
    # check the session would run on freshly initialized weights.
    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoint found in %r' % (checkpoint_dir,))

    with tf.Graph().as_default():
        config = _CONFIG.copy()
        config['mode'] = 'validation'
        model = DeepSentiment(config)

        # Load model
        scaffold = monitored_session.Scaffold(
            init_op=None, init_feed_dict=None,
            init_fn=None, saver=None)
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_filename_with_path=checkpoint_path,
            master='',
            config=None)

        im_features_size = config['im_features_size']
        rnn_size = config['rnn_size']
        batch_size = config['batch_size']
        # Floor division: '/' yields a float on Python 3 and range() rejects it.
        nb_batches = model.dataset.num_samples // batch_size
        if nb_batches <= 0:
            raise ValueError(
                'Dataset has fewer samples (%r) than one batch (%r)'
                % (model.dataset.num_samples, batch_size))

        dense_mean = np.zeros(im_features_size + rnn_size)
        # Entering the MonitoredSession generates the input queue.
        with monitored_session.MonitoredSession(
            session_creator=session_creator, hooks=None) as session:
            # Running mean of the per-batch means: after batch i the old mean
            # has weight i / (i + 1) and the new batch mean 1 / (i + 1).
            for i in range(nb_batches):
                current_dense = session.run(model.concat_features)
                weight = float(i) / (i + 1)
                dense_mean = (weight * dense_mean +
                              (1 - weight) * current_dense.mean(axis=0))

            # Now look at outliers
            max_norms = np.zeros(batch_size)
            max_post_ids = np.zeros(batch_size)
            max_logits = np.zeros((batch_size, model.dataset.num_classes))
            # NOTE(review): the maximum is tracked per position inside a batch,
            # not globally, so two outliers landing at the same position in
            # different batches keep only the larger one. Confirm intended.
            for _ in range(nb_batches):
                current_dense, np_post_ids, current_logits = session.run(
                    [model.concat_features, model.post_ids, model.logits])
                current_diff = np.linalg.norm(current_dense - dense_mean, axis=1)
                is_new_max = current_diff > max_norms
                max_norms[is_new_max] = current_diff[is_new_max]
                max_post_ids[is_new_max] = np_post_ids[is_new_max]
                max_logits[is_new_max] = current_logits[is_new_max]

    np.save('data/max_norms.npy', max_norms)
    np.save('data/max_post_ids.npy', max_post_ids)
    np.save('data/max_logits.npy', max_logits)
    return max_norms, max_post_ids, max_logits
开发者ID:anthonyhu,项目名称:tumblr-emotions,代码行数:54,代码来源:im_text_rnn_model.py

示例8: day_of_week_trend

# 需要导入模块: from tensorflow.python.training import monitored_session [as 别名]
# 或者: from tensorflow.python.training.monitored_session import Scaffold [as 别名]
def day_of_week_trend(checkpoint_dir):
    """Computes day of week trend data for the validation posts.

    Collects the logits, labels, days and post ids of every full validation
    batch and saves them as numpy files.

    Parameters:
        checkpoint_dir: Directory containing the checkpoints saved during
            training; its latest checkpoint is restored.

    Returns:
        (posts_logits, posts_labels, posts_days, posts_ids): the logits stacked
        with np.vstack and the other three concatenated with np.hstack. They
        are also written to data/posts_logits_week.npy,
        data/posts_labels_week.npy, data/posts_days_week.npy and
        data/posts_ids_week.npy.

    Raises:
        ValueError: if checkpoint_dir contains no checkpoint, or the dataset
            holds fewer samples than one batch.
    """
    # latest_checkpoint returns None when no checkpoint exists; without this
    # check the session would run on freshly initialized weights.
    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoint found in %r' % (checkpoint_dir,))

    with tf.Graph().as_default():
        config = _CONFIG.copy()
        config['mode'] = 'validation'
        model = DeepSentiment(config)

        # Load model
        scaffold = monitored_session.Scaffold(
            init_op=None, init_feed_dict=None,
            init_fn=None, saver=None)
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_filename_with_path=checkpoint_path,
            master='',
            config=None)

        batch_size = config['batch_size']
        # Floor division: '/' yields a float on Python 3 and range() rejects it.
        nb_batches = model.dataset.num_samples // batch_size
        if nb_batches <= 0:
            # np.vstack below would otherwise fail on an empty list.
            raise ValueError(
                'Dataset has fewer samples (%r) than one batch (%r)'
                % (model.dataset.num_samples, batch_size))

        posts_logits = []
        posts_labels = []
        posts_days = []
        posts_ids = []
        # Entering the MonitoredSession generates the input queue.
        with monitored_session.MonitoredSession(
            session_creator=session_creator, hooks=None) as session:
            for _ in range(nb_batches):
                np_logits, np_labels, np_days, np_post_ids = session.run(
                    [model.logits, model.labels, model.days, model.post_ids])
                posts_logits.append(np_logits)
                posts_labels.append(np_labels)
                posts_days.append(np_days)
                posts_ids.append(np_post_ids)

    posts_logits, posts_labels = np.vstack(posts_logits), np.hstack(posts_labels)
    posts_days, posts_ids = np.hstack(posts_days), np.hstack(posts_ids)
    np.save('data/posts_logits_week.npy', posts_logits)
    np.save('data/posts_labels_week.npy', posts_labels)
    np.save('data/posts_days_week.npy', posts_days)
    np.save('data/posts_ids_week.npy', posts_ids)
    return posts_logits, posts_labels, posts_days, posts_ids
开发者ID:anthonyhu,项目名称:tumblr-emotions,代码行数:47,代码来源:im_text_rnn_model.py


注:本文中的tensorflow.python.training.monitored_session.Scaffold方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。