当前位置: 首页>>代码示例>>Python>>正文


Python v1.get_default_session方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.get_default_session方法的典型用法代码示例。如果您正苦于以下问题:Python v1.get_default_session方法的具体用法?Python v1.get_default_session怎么用?Python v1.get_default_session使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.get_default_session方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _load_checkpoint

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def _load_checkpoint(checkpoint_filename, extra_vars, trainable_only=False):
  if tf.gfile.IsDirectory(checkpoint_filename):
    checkpoint_filename = tf.train.latest_checkpoint(checkpoint_filename)
  logging.info('Loading checkpoint %s', checkpoint_filename)
  saveables = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
               tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
  if trainable_only:
    saveables = list(set(saveables) & set(tf.trainable_variables()))
  # Try to restore all saveables, if that fails try without extra_vars.
  try:
    saver = tf.train.Saver(var_list=saveables)
    saver.restore(tf.get_default_session(), checkpoint_filename)
  except (ValueError, tf.errors.NotFoundError):
    logging.info('Missing key in checkpoint. Trying old checkpoint format.')
    saver = tf.train.Saver(var_list=list(set(saveables) - set(extra_vars)))
    saver.restore(tf.get_default_session(), checkpoint_filename) 
开发者ID:deepmind,项目名称:lamb,代码行数:18,代码来源:training.py

示例2: evaluate_deterministic

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def evaluate_deterministic(
    model, make_data_iterator_fn, dataset_name, episodic,
    num_batches_to_discard=0, temperature=1.0, print_level=2,
    prediction_callback=None, extra_ops=()):
  """Evaluate with a single pass with dropout turned off."""
  sum_xe = 0
  sum_len = 0
  num_batches = 0
  last_state = None
  for (cond, cond_len, source, source_len, target) in make_data_iterator_fn():
    feed = _make_feed(model, cond, cond_len, source, source_len, target,
                      last_state, episodic, 0, temperature)
    xe, last_state = tf.get_default_session().run(
        [model.xe_losses, model.last_state]+list(extra_ops), feed)[0:2]
    if num_batches >= num_batches_to_discard:
      sum_xe1, sum_len1 = _sum_masked(source_len, xe)
      sum_xe += sum_xe1
      sum_len += sum_len1
    if prediction_callback:
      prediction_callback(target, source_len, xe)
    num_batches += 1
  average_xe = sum_xe / sum_len
  if print_level >= 1:
    logging.info('final %s xe: %6.5f (%s), batches: %s',
                 dataset_name, average_xe, sum_len,
                 num_batches-num_batches_to_discard)
  return average_xe 
开发者ID:deepmind,项目名称:lamb,代码行数:29,代码来源:evaluation.py

示例3: save

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def save(self):
    tf.get_default_session().run(self._save) 
开发者ID:deepmind,项目名称:lamb,代码行数:4,代码来源:dyneval.py

示例4: restore

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def restore(self):
    tf.get_default_session().run(self._restore) 
开发者ID:deepmind,项目名称:lamb,代码行数:4,代码来源:dyneval.py

示例5: take_sample

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def take_sample(self):
    tf.get_default_session().run(self._take_sample) 
开发者ID:deepmind,项目名称:lamb,代码行数:4,代码来源:averaged.py

示例6: switch_to_average

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def switch_to_average(self):
    tf.get_default_session().run(self._switch) 
开发者ID:deepmind,项目名称:lamb,代码行数:4,代码来源:averaged.py

示例7: reset

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def reset(self):
    tf.get_default_session().run(self._reset) 
开发者ID:deepmind,项目名称:lamb,代码行数:4,代码来源:averaged.py

示例8: store

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def store(self, obj):
    tf.get_default_session().run(self._assign_op,
                                 feed_dict={self._new_string: str(obj)}) 
开发者ID:deepmind,项目名称:lamb,代码行数:5,代码来源:utils.py

示例9: global_step

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def global_step(self, session=None):
    if session is None:
      session = tf.get_default_session()
    return session.run(self.global_step_var) 
开发者ID:deepmind,项目名称:lamb,代码行数:6,代码来源:lm.py

示例10: fit

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def fit(self, feed, session=None):
    """Training step for observed source language example."""
    if session is None:
      session = tf.get_default_session()
    run_options = tf.RunOptions(
        report_tensor_allocations_upon_oom=True)
    _, cost, summary, last_state = session.run(
        [self.training_update, self.unregularized_loss, self.training_summary,
         self.last_state],
        feed_dict=feed, options=run_options)
    return cost, summary, last_state 
开发者ID:deepmind,项目名称:lamb,代码行数:13,代码来源:lm.py

示例11: accumulate_gradients

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def accumulate_gradients(self, feed, session=None):
    if session is None:
      session = tf.get_default_session()
    _, cost, summary, last_state = session.run(
        [self.accumulate_grads, self.unregularized_loss,
         self.training_summary, self.last_state],
        feed_dict=feed)
    return cost, summary, last_state 
开发者ID:deepmind,项目名称:lamb,代码行数:10,代码来源:lm.py

示例12: fit_accumulated

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def fit_accumulated(self, feed, session=None):
    """Training step for observed source language example."""
    if session is None:
      session = tf.get_default_session()
    session.run([self.accumulated_training_update], feed_dict=feed) 
开发者ID:deepmind,项目名称:lamb,代码行数:7,代码来源:lm.py

示例13: reset_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def reset_graph():
  """Closes the current default session and resets the graph."""
  sess = tf.get_default_session()
  if sess:
    sess.close()
  tf.reset_default_graph() 
开发者ID:magenta,项目名称:magenta,代码行数:8,代码来源:sketch_rnn_train.py

示例14: get_session

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def get_session(config=None):
    """Get default session or create one with a given config"""
    sess = tf.get_default_session()
    if sess is None:
        sess = make_session(config=config, make_default=True)
    return sess 
开发者ID:microsoft,项目名称:nni,代码行数:8,代码来源:util.py

示例15: get_inception_probs

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def get_inception_probs(inps):
    session=tf.get_default_session()
    n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE))
    preds = np.zeros([inps.shape[0], 1000], dtype = np.float32)
    for i in range(n_batches):
        inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1
        preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits,{inception_images: inp})[:, :1000]
    preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True)
    return preds 
开发者ID:tsc2017,项目名称:Inception-Score,代码行数:11,代码来源:inception_score.py


注:本文中的tensorflow.compat.v1.get_default_session方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。