當前位置: 首頁>>代碼示例>>Python>>正文


Python v1.get_default_session方法代碼示例

本文整理匯總了Python中tensorflow.compat.v1.get_default_session方法的典型用法代碼示例。如果您正苦於以下問題:Python v1.get_default_session方法的具體用法?Python v1.get_default_session怎麽用?Python v1.get_default_session使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.get_default_session方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: _load_checkpoint

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def _load_checkpoint(checkpoint_filename, extra_vars, trainable_only=False):
  if tf.gfile.IsDirectory(checkpoint_filename):
    checkpoint_filename = tf.train.latest_checkpoint(checkpoint_filename)
  logging.info('Loading checkpoint %s', checkpoint_filename)
  saveables = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
               tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
  if trainable_only:
    saveables = list(set(saveables) & set(tf.trainable_variables()))
  # Try to restore all saveables, if that fails try without extra_vars.
  try:
    saver = tf.train.Saver(var_list=saveables)
    saver.restore(tf.get_default_session(), checkpoint_filename)
  except (ValueError, tf.errors.NotFoundError):
    logging.info('Missing key in checkpoint. Trying old checkpoint format.')
    saver = tf.train.Saver(var_list=list(set(saveables) - set(extra_vars)))
    saver.restore(tf.get_default_session(), checkpoint_filename) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:18,代碼來源:training.py

示例2: evaluate_deterministic

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def evaluate_deterministic(
    model, make_data_iterator_fn, dataset_name, episodic,
    num_batches_to_discard=0, temperature=1.0, print_level=2,
    prediction_callback=None, extra_ops=()):
  """Evaluate with a single pass with dropout turned off."""
  sum_xe = 0
  sum_len = 0
  num_batches = 0
  last_state = None
  for (cond, cond_len, source, source_len, target) in make_data_iterator_fn():
    feed = _make_feed(model, cond, cond_len, source, source_len, target,
                      last_state, episodic, 0, temperature)
    xe, last_state = tf.get_default_session().run(
        [model.xe_losses, model.last_state]+list(extra_ops), feed)[0:2]
    if num_batches >= num_batches_to_discard:
      sum_xe1, sum_len1 = _sum_masked(source_len, xe)
      sum_xe += sum_xe1
      sum_len += sum_len1
    if prediction_callback:
      prediction_callback(target, source_len, xe)
    num_batches += 1
  average_xe = sum_xe / sum_len
  if print_level >= 1:
    logging.info('final %s xe: %6.5f (%s), batches: %s',
                 dataset_name, average_xe, sum_len,
                 num_batches-num_batches_to_discard)
  return average_xe 
開發者ID:deepmind,項目名稱:lamb,代碼行數:29,代碼來源:evaluation.py

示例3: save

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def save(self):
    tf.get_default_session().run(self._save) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:4,代碼來源:dyneval.py

示例4: restore

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def restore(self):
    tf.get_default_session().run(self._restore) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:4,代碼來源:dyneval.py

示例5: take_sample

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def take_sample(self):
    tf.get_default_session().run(self._take_sample) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:4,代碼來源:averaged.py

示例6: switch_to_average

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def switch_to_average(self):
    tf.get_default_session().run(self._switch) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:4,代碼來源:averaged.py

示例7: reset

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def reset(self):
    tf.get_default_session().run(self._reset) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:4,代碼來源:averaged.py

示例8: store

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def store(self, obj):
    tf.get_default_session().run(self._assign_op,
                                 feed_dict={self._new_string: str(obj)}) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:5,代碼來源:utils.py

示例9: global_step

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def global_step(self, session=None):
    if session is None:
      session = tf.get_default_session()
    return session.run(self.global_step_var) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:6,代碼來源:lm.py

示例10: fit

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def fit(self, feed, session=None):
    """Training step for observed source language example."""
    if session is None:
      session = tf.get_default_session()
    run_options = tf.RunOptions(
        report_tensor_allocations_upon_oom=True)
    _, cost, summary, last_state = session.run(
        [self.training_update, self.unregularized_loss, self.training_summary,
         self.last_state],
        feed_dict=feed, options=run_options)
    return cost, summary, last_state 
開發者ID:deepmind,項目名稱:lamb,代碼行數:13,代碼來源:lm.py

示例11: accumulate_gradients

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def accumulate_gradients(self, feed, session=None):
    if session is None:
      session = tf.get_default_session()
    _, cost, summary, last_state = session.run(
        [self.accumulate_grads, self.unregularized_loss,
         self.training_summary, self.last_state],
        feed_dict=feed)
    return cost, summary, last_state 
開發者ID:deepmind,項目名稱:lamb,代碼行數:10,代碼來源:lm.py

示例12: fit_accumulated

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def fit_accumulated(self, feed, session=None):
    """Training step for observed source language example."""
    if session is None:
      session = tf.get_default_session()
    session.run([self.accumulated_training_update], feed_dict=feed) 
開發者ID:deepmind,項目名稱:lamb,代碼行數:7,代碼來源:lm.py

示例13: reset_graph

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def reset_graph():
  """Closes the current default session and resets the graph."""
  sess = tf.get_default_session()
  if sess:
    sess.close()
  tf.reset_default_graph() 
開發者ID:magenta,項目名稱:magenta,代碼行數:8,代碼來源:sketch_rnn_train.py

示例14: get_session

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def get_session(config=None):
    """Get default session or create one with a given config"""
    sess = tf.get_default_session()
    if sess is None:
        sess = make_session(config=config, make_default=True)
    return sess 
開發者ID:microsoft,項目名稱:nni,代碼行數:8,代碼來源:util.py

示例15: get_inception_probs

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 別名]
def get_inception_probs(inps):
    session=tf.get_default_session()
    n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE))
    preds = np.zeros([inps.shape[0], 1000], dtype = np.float32)
    for i in range(n_batches):
        inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1
        preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits,{inception_images: inp})[:, :1000]
    preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True)
    return preds 
開發者ID:tsc2017,項目名稱:Inception-Score,代碼行數:11,代碼來源:inception_score.py


注:本文中的tensorflow.compat.v1.get_default_session方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。