当前位置: 首页>>代码示例>>Python>>正文


Python framework.create_global_step函数代码示例

本文整理汇总了Python中tensorflow.contrib.framework.create_global_step函数的典型用法代码示例。如果您正苦于以下问题:Python create_global_step函数的具体用法?Python create_global_step怎么用?Python create_global_step使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了create_global_step函数的14个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _infer_model

  def _infer_model(
      self, input_fn, feed_fn=None, outputs=None, as_iterable=False):
    """Builds a prediction graph and hands it to the matching infer helper.

    Args:
      input_fn: Passed to `self._get_features_from_input_fn` to get features.
      feed_fn: Optional; forwarded unchanged to the inference helper.
      outputs: Optional collection of prediction keys to keep. When truthy,
        predictions whose key is not contained in it are dropped.
      as_iterable: When truthy, delegates to `_infer_model_as_iterable`,
        otherwise to `_infer_model_single`.

    Returns:
      Whatever the selected inference helper returns.

    Raises:
      NotFittedError: If no checkpoint is found for `self._model_dir`.
      ValueError: If filtering by `outputs` leaves no prediction to run.
    """
    # A model without a checkpoint has not been trained yet.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features = self._get_features_from_input_fn(input_fn)
      predictions = self._get_predict_ops(features)

      # A single (non-dict) output is wrapped in a dict for uniform handling;
      # `return_dict` records whether the caller originally had a dict.
      return_dict = isinstance(predictions, dict)
      if not return_dict:
        predictions = {'predictions': predictions}

      # Restrict the predictions to the requested outputs, if any.
      if outputs:
        existing_keys = predictions.keys()
        selected = {}
        for key, value in predictions.items():
          if key in outputs:
            selected[key] = value
        predictions = selected
        if not predictions:
          raise ValueError('Expected to run at least one output from %s, '
                           'provided %s.' % (existing_keys, outputs))

      if as_iterable:
        infer_fn = self._infer_model_as_iterable
      else:
        infer_fn = self._infer_model_single
      return infer_fn(checkpoint_path, predictions, feed_fn, return_dict)
开发者ID:Nishant23,项目名称:tensorflow,代码行数:35,代码来源:estimator.py

示例2: _infer_model

  def _infer_model(self,
                   x=None, input_fn=None, feed_fn=None,
                   batch_size=None, axis=None, proba=False):
    """Runs inference from the latest checkpoint, optionally feed by feed.

    Args:
      x: Optional input data. When not None, `input_fn` and `feed_fn` are
        replaced by the pair returned from `_get_predict_input_fn`.
      input_fn: Callable returning a `(features, _)` pair.
      feed_fn: Optional callable producing feed dicts. Iteration stops when it
        raises `StopIteration` or returns None.
      batch_size: Batch size handed to `_get_predict_input_fn`; None means -1.
      axis: Not used by this method.
      proba: Not used by this method.

    Returns:
      The result of `infer` when there is no `feed_fn`; otherwise a dict
      mapping each output key to the per-feed results concatenated on axis 0.
    """
    # Converts inputs into tf.DataFrame / tf.Series.
    if batch_size is None:
      batch_size = -1
    if x is not None:
      input_fn, feed_fn = _get_predict_input_fn(x, batch_size)

    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features, _ = input_fn()
      predictions = self._get_predict_ops(features)
      if not isinstance(predictions, dict):
        predictions = {'predictions': predictions}
      # TODO(ipolosukhin): Support batching
      if feed_fn is None:
        return infer(checkpoint_path, predictions)

      # Collect the per-feed results of every output key.
      chunks_by_key = {}
      while True:
        try:
          feed_dict = feed_fn()
        except StopIteration:
          break
        if feed_dict is None:
          break
        batch = infer(checkpoint_path, predictions, feed_dict=feed_dict)
        for key in batch:
          chunks_by_key.setdefault(key, []).append(batch[key])

      # Stitch the chunks of each key together along the first axis.
      return {
          key: np.concatenate(chunks, axis=0)
          for key, chunks in chunks_by_key.items()
      }
开发者ID:01-,项目名称:tensorflow,代码行数:35,代码来源:estimator.py

示例3: _infer_model

  def _infer_model(self, x=None, input_fn=None, feed_fn=None, batch_size=None):
    """Runs inference from the latest checkpoint found for `self._model_dir`.

    Args:
      x: Optional input data. When not None, `input_fn` and `feed_fn` are
        replaced by the pair returned from `_get_predict_input_fn`.
      input_fn: Passed to `self._get_features_from_input_fn` to get features.
      feed_fn: Optional callable producing one feed dict per call.
      batch_size: Batch size handed to `_get_predict_input_fn`; None means -1.

    Returns:
      A dict of results when `_get_predict_ops` returned a dict, otherwise the
      single result stored under the 'predictions' key.
    """
    # Converts inputs into tf.DataFrame / tf.Series.
    batch_size = -1 if batch_size is None else batch_size
    if x is not None:
      input_fn, feed_fn = _get_predict_input_fn(x, None, batch_size)

    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(g)
      features = self._get_features_from_input_fn(input_fn)
      predictions = self._get_predict_ops(features)
      # Wrap a single output in a dict and remember to unwrap it on return.
      return_dict = True
      if not isinstance(predictions, dict):
        predictions, return_dict = {'predictions': predictions}, False
      if feed_fn is None:
        preds = infer(checkpoint_path, predictions)
      else:
        preds = {}
        def _feed_fn():
          """Yields feed dicts until `feed_fn` raises StopIteration."""
          while True:
            # Under PEP 479 (Python 3.7+) a StopIteration that escapes a
            # generator body is turned into RuntimeError, so end the generator
            # explicitly. NOTE(review): assumes feed_fn signals exhaustion by
            # raising StopIteration -- confirm against the feeder in use.
            try:
              feed_dict = feed_fn()
            except StopIteration:
              return
            yield feed_dict
        outputs = graph_actions.run_feeds(
            output_dict=predictions,
            feed_dicts=_feed_fn(),
            restore_checkpoint_path=checkpoint_path)
        # Join the per-feed results of every key along the first axis.
        for key in predictions:
          preds[key] = np.concatenate(
              [output[key] for output in outputs], axis=0)
      if return_dict:
        return preds
      return preds['predictions']
开发者ID:Baaaaam,项目名称:tensorflow,代码行数:32,代码来源:estimator.py

示例4: _infer_model

  def _infer_model(
      self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
    """Builds a prediction graph and hands it to the matching infer helper.

    Args:
      input_fn: Passed to `self._get_features_from_input_fn` to get features.
      feed_fn: Optional; forwarded unchanged to the inference helper.
      outputs: Optional collection of prediction keys to keep. When truthy,
        predictions whose key is not contained in it are dropped.
      as_iterable: When truthy, delegates to `_infer_model_as_iterable`,
        otherwise to `_infer_model_single`.

    Returns:
      Whatever the selected inference helper returns.

    Raises:
      NotFittedError: If no checkpoint is found for `self._model_dir`.
      ValueError: If filtering by `outputs` leaves no prediction to run.
    """
    # A model without a checkpoint has not been trained yet.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features = self._get_features_from_input_fn(input_fn)

      # _get_predict_ops normally returns a ModelFnOps, but some subclasses of
      # tf.contrib.learn.Estimator still override it with the legacy signature
      # and return the `predictions` Tensor (or dict of Tensors) directly.
      # TODO(b/32664904): Update subclasses and delete the legacy handling.
      infer_ops = self._get_predict_ops(features)
      predictions = infer_ops  # Legacy signature
      if isinstance(infer_ops, model_fn_lib.ModelFnOps):  # Default signature
        predictions = infer_ops.predictions

      # A single (non-dict) output is wrapped in a dict for uniform handling;
      # `return_dict` records whether the caller originally had a dict.
      return_dict = isinstance(predictions, dict)
      if not return_dict:
        predictions = {'predictions': predictions}

      # Restrict the predictions to the requested outputs, if any.
      if outputs:
        existing_keys = predictions.keys()
        selected = {}
        for key, value in six.iteritems(predictions):
          if key in outputs:
            selected[key] = value
        predictions = selected
        if not predictions:
          raise ValueError('Expected to run at least one output from %s, '
                           'provided %s.' % (existing_keys, outputs))

      if as_iterable:
        infer_fn = self._infer_model_as_iterable
      else:
        infer_fn = self._infer_model_single
      return infer_fn(checkpoint_path, predictions, feed_fn, return_dict)
开发者ID:HKUST-SING,项目名称:tensorflow,代码行数:49,代码来源:estimator.py

示例5: _evaluate_model

  def _evaluate_model(self,
                      input_fn,
                      steps,
                      feed_fn=None,
                      metrics=None,
                      name=''):
    """Builds an evaluation graph and runs `evaluate` on it.

    Args:
      input_fn: Callable returning a `(features, targets)` pair.
      steps: Forwarded to `evaluate` as `max_steps`.
      feed_fn: Optional; forwarded unchanged to `evaluate`.
      metrics: Forwarded to `self._get_eval_ops`.
      name: Optional suffix; results go to 'eval_<name>' instead of 'eval'
        under `self._model_dir` when it is non-empty.

    Returns:
      The first element of the pair returned by `evaluate`, or None when the
      configured execution mode does not include evaluation.
    """
    eval_modes = ('all', 'evaluate', 'eval_evalset')
    if self._config.execution_mode not in eval_modes:
      return

    checkpoint_path = self._model_dir
    eval_subdir = 'eval_' + name if name else 'eval'
    eval_dir = os.path.join(self._model_dir, eval_subdir)
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(graph)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      eval_dict = self._get_eval_ops(features, targets, metrics)
      update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
      eval_kwargs = dict(
          graph=graph,
          output_dir=eval_dir,
          checkpoint_path=checkpoint_path,
          eval_dict=eval_dict,
          update_op=update_op,
          global_step_tensor=global_step,
          supervisor_master=self._config.master,
          feed_fn=feed_fn,
          max_steps=steps)
      eval_results, _ = evaluate(**eval_kwargs)
      return eval_results
开发者ID:Baaaaam,项目名称:tensorflow,代码行数:29,代码来源:estimator.py

示例6: _infer_model

  def _infer_model(self, x, batch_size=None, axis=None, proba=False):
    """Runs a single `infer` call on `x` from the latest checkpoint.

    Args:
      x: Input data handed to `_get_predict_input_fn`.
      batch_size: Batch size handed to `_get_predict_input_fn`; None means -1.
      axis: Not used by this method.
      proba: Not used by this method.

    Returns:
      The result of `infer` for the prediction dict.
    """
    # Converts inputs into tf.DataFrame / tf.Series.
    if batch_size is None:
      batch_size = -1
    input_fn, feed_fn = _get_predict_input_fn(x, batch_size)

    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features, _ = input_fn()
      # Only one feed dict is requested; without a feed_fn none is used.
      feed_dict = None
      if feed_fn is not None:
        feed_dict = feed_fn()
      predictions = self._get_predict_ops(features)
      if not isinstance(predictions, dict):
        predictions = {'predictions': predictions}
      # TODO(ipolosukhin): Support batching
      return infer(checkpoint_path, predictions, feed_dict=feed_dict)
开发者ID:Absarvar,项目名称:tensorflow,代码行数:16,代码来源:estimator.py

示例7: _infer_model

  def _infer_model(self, input_fn, feed_fn=None, outputs=None):
    """Runs inference from the latest checkpoint found for `self._model_dir`.

    Args:
      input_fn: Passed to `self._get_features_from_input_fn` to get features.
      feed_fn: Optional callable producing one feed dict per call.
      outputs: Optional collection of prediction keys to keep. When truthy,
        predictions whose key is not contained in it are dropped.

    Returns:
      A dict of results when `_get_predict_ops` returned a dict, otherwise the
      single result stored under the 'predictions' key.

    Raises:
      NotFittedError: If no checkpoint is found for `self._model_dir`.
      ValueError: If filtering by `outputs` leaves no prediction to run.
    """
    # Check that model has been trained.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(g)
      features = self._get_features_from_input_fn(input_fn)
      predictions = self._get_predict_ops(features)
      # If predictions is single output - wrap it into dict, and remember to
      # return not a dict.
      return_dict = True
      if not isinstance(predictions, dict):
        predictions, return_dict = {'predictions': predictions}, False
      # Filter what to run predictions on, if outputs provided.
      if outputs:
        existing_keys = predictions.keys()
        predictions = {
            key: value for key, value in predictions.items() if key in outputs
        }
        if not predictions:
          raise ValueError('Expected to run at least one output from %s, '
                           'provided %s.' % (existing_keys, outputs))
      if feed_fn is None:
        preds = graph_actions.infer(checkpoint_path, predictions)
      else:
        preds = {}
        def _feed_fn():
          """Yields feed dicts until `feed_fn` raises StopIteration."""
          while True:
            # Under PEP 479 (Python 3.7+) a StopIteration that escapes a
            # generator body is turned into RuntimeError, so end the generator
            # explicitly. NOTE(review): assumes feed_fn signals exhaustion by
            # raising StopIteration -- confirm against the feeder in use.
            try:
              feed_dict = feed_fn()
            except StopIteration:
              return
            yield feed_dict
        # Use a dedicated name so the `outputs` argument is not rebound.
        feed_results = graph_actions.run_feeds(
            output_dict=predictions,
            feed_dicts=_feed_fn(),
            restore_checkpoint_path=checkpoint_path)
        # Join the per-feed results of every key along the first axis.
        for key in predictions:
          preds[key] = np.concatenate(
              [result[key] for result in feed_results], axis=0)
      if return_dict:
        return preds
      return preds['predictions']
开发者ID:carloscampo5200,项目名称:tensorflow,代码行数:43,代码来源:estimator.py

示例8: _evaluate_model

  def _evaluate_model(self,
                      input_fn,
                      steps,
                      feed_fn=None,
                      metrics=None,
                      name=''):
    """Builds an evaluation graph and runs `graph_actions.evaluate` on it.

    Args:
      input_fn: Callable returning a `(features, labels)` pair.
      steps: Forwarded to `graph_actions.evaluate` as `max_steps`.
      feed_fn: Optional; forwarded unchanged to `graph_actions.evaluate`.
      metrics: Forwarded to `self._get_eval_ops`.
      name: Optional suffix; results go to 'eval_<name>' instead of 'eval'
        under `self._model_dir` when it is non-empty.

    Returns:
      The `(eval_results, current_global_step)` pair from
      `graph_actions.evaluate`, or `(None, None)` when the config has an
      `execution_mode` that does not include evaluation.

    Raises:
      NotFittedError: If no checkpoint is found for `self._model_dir`.
    """
    # TODO(wicke): Remove this once Model and associated code are gone.
    if hasattr(self._config, 'execution_mode'):
      eval_modes = ('all', 'evaluate', 'eval_evalset')
      if self._config.execution_mode not in eval_modes:
        return None, None

    # Evaluation needs a trained model; the directory itself (not the latest
    # checkpoint file) is what gets passed on as `checkpoint_path`.
    checkpoint_path = self._model_dir
    latest_path = saver.latest_checkpoint(checkpoint_path)
    if not latest_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % checkpoint_path)

    # Setup output directory.
    eval_subdir = 'eval_' + name if name else 'eval'
    eval_dir = os.path.join(self._model_dir, eval_subdir)

    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(graph)
      features, labels = input_fn()
      self._check_inputs(features, labels)

      # _get_eval_ops normally returns a ModelFnOps, but some subclasses of
      # tf.contrib.learn.Estimator still override it with the legacy signature
      # and return the `eval_dict` dictionary of Tensors directly.
      # TODO(b/32664904): Update subclasses and delete the legacy handling.
      eval_ops = self._get_eval_ops(features, labels, metrics)
      eval_dict = eval_ops  # Legacy signature
      if isinstance(eval_ops, ModelFnOps):  # Default signature
        eval_dict = eval_ops.eval_metric_ops

      update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
      eval_kwargs = dict(
          graph=graph,
          output_dir=eval_dir,
          checkpoint_path=checkpoint_path,
          eval_dict=eval_dict,
          update_op=update_op,
          global_step_tensor=global_step,
          supervisor_master=self._config.evaluation_master,
          feed_fn=feed_fn,
          max_steps=steps)
      eval_results, current_global_step = graph_actions.evaluate(**eval_kwargs)

      return eval_results, current_global_step
开发者ID:chinnadhurai,项目名称:block_rnn,代码行数:53,代码来源:estimator.py

示例9: _train_model

  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   device_fn=None,
                   monitor=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True):
    """Builds a training graph and runs `train` on it.

    Args:
      input_fn: Callable returning a `(features, targets)` pair.
      steps: Forwarded to `train` as `max_steps`.
      feed_fn: Optional; forwarded unchanged to `train`.
      device_fn: Optional device function; `self._device_fn` is used when it
        is falsy.
      monitor: Not used by this method.
      log_every_steps: Forwarded unchanged to `train`.
      fail_on_nan_loss: Forwarded unchanged to `train`.

    Returns:
      The result of `train`, or None when the configured execution mode does
      not include training.
    """
    config = self._config
    if config.execution_mode not in ('all', 'train'):
      return

    # Stagger startup of worker sessions based on task id, capped at the
    # configured maximum.
    stagger_secs = (
        config.task * config.training_worker_session_startup_stagger_secs)
    sleep_secs = min(config.training_worker_max_startup_secs, stagger_secs)
    if sleep_secs:
      logging.info('Waiting %d secs before starting task %d.', sleep_secs,
                   self._config.task)
      time.sleep(sleep_secs)

    # Device allocation
    if not device_fn:
      device_fn = self._device_fn

    with ops.Graph().as_default() as g, g.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      train_op, loss_op = self._get_train_ops(features, targets)
      train_kwargs = dict(
          graph=g,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step,
          log_every_steps=log_every_steps,
          supervisor_is_chief=(self._config.task == 0),
          supervisor_master=self._config.master,
          feed_fn=feed_fn,
          max_steps=steps,
          fail_on_nan_loss=fail_on_nan_loss)
      return train(**train_kwargs)
开发者ID:Absarvar,项目名称:tensorflow,代码行数:41,代码来源:estimator.py

示例10: _evaluate_model

  def _evaluate_model(self,
                      input_fn,
                      steps,
                      feed_fn=None,
                      metrics=None,
                      name=''):
    """Builds an evaluation graph and runs `graph_actions.evaluate` on it.

    Args:
      input_fn: Callable returning a `(features, targets)` pair.
      steps: Forwarded to `graph_actions.evaluate` as `max_steps`.
      feed_fn: Optional; forwarded unchanged to `graph_actions.evaluate`.
      metrics: Forwarded to `self._get_eval_ops`.
      name: Optional suffix; results go to 'eval_<name>' instead of 'eval'
        under `self._model_dir` when it is non-empty.

    Returns:
      The `(eval_results, current_global_step)` pair from
      `graph_actions.evaluate`, or `(None, None)` when the config has an
      `execution_mode` that does not include evaluation.

    Raises:
      NotFittedError: If no checkpoint is found for `self._model_dir`.
    """
    # TODO(wicke): Remove this once Model and associated code are gone.
    if hasattr(self._config, 'execution_mode'):
      eval_modes = ('all', 'evaluate', 'eval_evalset')
      if self._config.execution_mode not in eval_modes:
        return None, None

    # Evaluation needs a trained model; the directory itself (not the latest
    # checkpoint file) is what gets passed on as `checkpoint_path`.
    checkpoint_path = self._model_dir
    latest_path = saver.latest_checkpoint(checkpoint_path)
    if not latest_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % checkpoint_path)

    # Setup output directory.
    eval_subdir = 'eval_' + name if name else 'eval'
    eval_dir = os.path.join(self._model_dir, eval_subdir)

    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(graph)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      eval_dict = self._get_eval_ops(features, targets, metrics)
      update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
      eval_kwargs = dict(
          graph=graph,
          output_dir=eval_dir,
          checkpoint_path=checkpoint_path,
          eval_dict=eval_dict,
          update_op=update_op,
          global_step_tensor=global_step,
          supervisor_master=self._config.evaluation_master,
          feed_fn=feed_fn,
          max_steps=steps)
      eval_results, current_global_step = graph_actions.evaluate(**eval_kwargs)

      return eval_results, current_global_step
开发者ID:Nishant23,项目名称:tensorflow,代码行数:40,代码来源:estimator.py

示例11: _evaluate_model

  def _evaluate_model(self, input_fn, steps, feed_fn=None, metrics=None):
    """Builds an evaluation graph and runs `evaluate` on it.

    Args:
      input_fn: Callable returning a `(features, targets)` pair.
      steps: Forwarded to `evaluate` as `max_steps`.
      feed_fn: Optional; forwarded unchanged to `evaluate`.
      metrics: Metrics handed to `self._get_eval_ops`; when None, the result
        of `self._get_default_metric_functions()` is used instead.

    Returns:
      The first element of the pair returned by `evaluate`, or None when the
      configured execution mode does not include evaluation.
    """
    eval_modes = ('all', 'evaluate', 'eval_evalset')
    if self._config.execution_mode not in eval_modes:
      return

    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    eval_dir = os.path.join(self._model_dir, 'eval')
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(graph)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      # Fall back to the default metrics only when none were supplied.
      if metrics is None:
        metrics = self._get_default_metric_functions()
      eval_dict = self._get_eval_ops(features, targets, metrics)
      eval_kwargs = dict(
          graph=graph,
          output_dir=eval_dir,
          checkpoint_path=checkpoint_path,
          eval_dict=eval_dict,
          global_step_tensor=global_step,
          supervisor_master=self._config.master,
          feed_fn=feed_fn,
          max_steps=steps)
      eval_results, _ = evaluate(**eval_kwargs)
      return eval_results
开发者ID:277801235,项目名称:tensorflow,代码行数:24,代码来源:estimator.py

示例12: _train_model

  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True,
                   max_steps=None):
    """Builds a training graph and runs `graph_actions._monitored_train`.

    Args:
      input_fn: Callable returning a `(features, targets)` pair.
      steps: Forwarded unchanged to `_monitored_train`.
      feed_fn: Optional; forwarded unchanged to `_monitored_train`.
      init_op: Forwarded unchanged to `_monitored_train`.
      init_feed_fn: Optional callable; its result is passed as
        `init_feed_dict`, or None when it is not given.
      init_fn: Forwarded unchanged to `_monitored_train`.
      device_fn: Optional device function; `self._device_fn` is used when it
        is falsy.
      monitors: Optional list mixing `SessionRunHook` instances and
        deprecated monitors; None is treated as an empty list.
      log_every_steps: Forwarded unchanged to `_monitored_train`.
      fail_on_nan_loss: Forwarded unchanged to `_monitored_train`.
      max_steps: Forwarded unchanged to `_monitored_train`.

    Returns:
      The result of `_monitored_train`, or None when the config has an
      `execution_mode` that does not include training.
    """
    # TODO(wicke): Remove this once Model and associated code are gone.
    if hasattr(self._config, 'execution_mode'):
      if self._config.execution_mode not in ('all', 'train'):
        return

      # Stagger startup of worker sessions based on task id, capped at the
      # configured maximum.
      config = self._config
      stagger_secs = (
          config.task * config.training_worker_session_startup_stagger_secs)
      sleep_secs = min(config.training_worker_max_startup_secs, stagger_secs)
      if sleep_secs:
        logging.info('Waiting %d secs before starting task %d.', sleep_secs,
                     self._config.task)
        time.sleep(sleep_secs)

    # Device allocation
    if not device_fn:
      device_fn = self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      train_op, loss_op = self._get_train_ops(features, targets)

      # Add default monitors.
      if monitors is None:
        monitors = []

      # Split the mixed list into real hooks and deprecated monitors, keeping
      # the relative order inside each group.
      hooks = []
      deprecated_monitors = []
      for m in monitors:
        if isinstance(m, session_run_hook.SessionRunHook):
          hooks.append(m)
        else:
          deprecated_monitors.append(m)

      supervisor_is_chief = self._config.is_chief
      if not supervisor_is_chief:
        # Prune list of monitor to the ones runnable on all workers.
        deprecated_monitors = [m for m in deprecated_monitors
                               if m.run_on_all_workers]

      # Setup monitors.
      for monitor in deprecated_monitors:
        monitor.set_estimator(self)

      # Deprecated monitors are run through a single adapter hook.
      if deprecated_monitors:
        hooks.append(monitor_lib.RunHookAdapterForMonitors(deprecated_monitors))

      return graph_actions._monitored_train(  # pylint: disable=protected-access
          graph=g,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step,
          init_op=init_op,
          init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=supervisor_is_chief,
          supervisor_master=self._config.master,
          supervisor_save_model_secs=self._config.save_checkpoints_secs,
          supervisor_save_summaries_steps=self._config.save_summary_steps,
          keep_checkpoint_max=self._config.keep_checkpoint_max,
          feed_fn=feed_fn,
          steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          hooks=hooks,
          max_steps=max_steps)
开发者ID:Nishant23,项目名称:tensorflow,代码行数:83,代码来源:estimator.py

示例13: _train_model

  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True):
    """Builds a training graph and runs `train` on it.

    Args:
      input_fn: Callable returning a `(features, targets)` pair.
      steps: Forwarded to `train` as `max_steps`.
      feed_fn: Optional; forwarded unchanged to `train`.
      init_op: Forwarded unchanged to `train`.
      init_feed_fn: Optional callable; its result is passed as
        `init_feed_dict`, or None when it is not given.
      init_fn: Forwarded unchanged to `train`.
      device_fn: Optional device function; `self._device_fn` is used when it
        is falsy.
      monitors: Optional iterable of monitors. It is copied, so the caller's
        object is never modified; default monitors are appended to the copy.
      log_every_steps: Forwarded unchanged to `train`.
      fail_on_nan_loss: Forwarded unchanged to `train`.

    Returns:
      The result of `train`, or None when the configured execution mode does
      not include training.
    """
    if self._config.execution_mode not in ('all', 'train'):
      return

    # Stagger startup of worker sessions based on task id.
    sleep_secs = min(self._config.training_worker_max_startup_secs,
                     self._config.task *
                     self._config.training_worker_session_startup_stagger_secs)
    if sleep_secs:
      logging.info('Waiting %d secs before starting task %d.', sleep_secs,
                   self._config.task)
      time.sleep(sleep_secs)

    # Device allocation
    device_fn = device_fn or self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      train_op, loss_op = self._get_train_ops(features, targets)

      # Add default monitors. Work on a copy: an in-place `+=` would extend the
      # list object owned by the caller, so default monitors would pile up on
      # it with every call.
      monitors = [] if monitors is None else list(monitors)
      monitors.extend(monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=100,
          summary_writer=graph_actions.get_summary_writer(self._model_dir)))

      is_chief = self._config.task == 0
      if not is_chief:
        # Run monitors only on chief.
        monitors = []

      # Setup monitors.
      for monitor in monitors:
        monitor.set_estimator(self)

      return train(
          graph=g,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step,
          init_op=init_op,
          init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=is_chief,
          supervisor_master=self._config.master,
          feed_fn=feed_fn,
          max_steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          monitors=monitors)
开发者ID:Baaaaam,项目名称:tensorflow,代码行数:68,代码来源:estimator.py

示例14: _train_model

  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True,
                   max_steps=None):
    """Builds a training graph and runs `graph_actions._monitored_train`.

    Args:
      input_fn: Callable returning a `(features, labels)` pair.
      steps: Forwarded unchanged to `_monitored_train`.
      feed_fn: Optional; forwarded unchanged to `_monitored_train`.
      init_op: Forwarded unchanged to `_monitored_train`.
      init_feed_fn: Optional callable; its result is passed as
        `init_feed_dict`, or None when it is not given.
      init_fn: Forwarded unchanged to `_monitored_train`.
      device_fn: Optional device function; `self._device_fn` is used when it
        is falsy.
      monitors: Handed to `monitor_lib.replace_monitors_with_hooks` together
        with this estimator to obtain the hooks.
      log_every_steps: Forwarded unchanged to `_monitored_train`.
      fail_on_nan_loss: Forwarded unchanged to `_monitored_train`.
      max_steps: Forwarded unchanged to `_monitored_train`.

    Returns:
      The result of `_monitored_train`, or None when the config has an
      `execution_mode` that does not include training.

    Raises:
      ValueError: If a legacy `_get_train_ops` result does not have length 2.
    """
    # TODO(wicke): Remove this once Model and associated code are gone.
    if hasattr(self._config, 'execution_mode'):
      if self._config.execution_mode not in ('all', 'train'):
        return

      # Stagger startup of worker sessions based on task id, capped at the
      # configured maximum.
      config = self._config
      stagger_secs = (
          config.task * config.training_worker_session_startup_stagger_secs)
      sleep_secs = min(config.training_worker_max_startup_secs, stagger_secs)
      if sleep_secs:
        logging.info('Waiting %d secs before starting task %d.', sleep_secs,
                     self._config.task)
        time.sleep(sleep_secs)

    # Device allocation
    if not device_fn:
      device_fn = self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, labels = input_fn()
      self._check_inputs(features, labels)

      # _get_train_ops normally returns a ModelFnOps, but some subclasses of
      # tf.contrib.learn.Estimator still override it with the legacy signature
      # and return a (train_op, loss) tuple instead.
      # TODO(b/32664904): Update subclasses and delete the legacy handling.
      train_ops = self._get_train_ops(features, labels)
      is_model_fn_ops = isinstance(train_ops, ModelFnOps)
      if not is_model_fn_ops and len(train_ops) != 2:
        raise ValueError('Expected a tuple of train_op and loss, got {}'.
                         format(train_ops))
      if is_model_fn_ops:  # Default signature
        train_op = train_ops.train_op
        loss_op = train_ops.loss
      else:  # Legacy signature
        train_op = train_ops[0]
        loss_op = train_ops[1]

      hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)

      # Register the loss in the LOSSES collection before training starts.
      ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
      return graph_actions._monitored_train(  # pylint: disable=protected-access
          graph=g,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step,
          init_op=init_op,
          init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=self.config.is_chief,
          supervisor_master=self._config.master,
          supervisor_save_model_secs=self._config.save_checkpoints_secs,
          supervisor_save_model_steps=self._config.save_checkpoints_steps,
          supervisor_save_summaries_steps=self._config.save_summary_steps,
          keep_checkpoint_max=self._config.keep_checkpoint_max,
          feed_fn=feed_fn,
          steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          hooks=hooks,
          max_steps=max_steps)
开发者ID:chinnadhurai,项目名称:block_rnn,代码行数:78,代码来源:estimator.py


注:本文中的tensorflow.contrib.framework.create_global_step函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。