本文整理汇总了Python中tensorflow.compat.v1.get_default_session方法的典型用法代码示例。如果您正苦于以下问题:Python v1.get_default_session方法的具体用法?Python v1.get_default_session怎么用?Python v1.get_default_session使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.compat.v1的用法示例。
在下文中一共展示了v1.get_default_session方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _load_checkpoint
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def _load_checkpoint(checkpoint_filename, extra_vars, trainable_only=False):
    """Restore saveables from a checkpoint into the default session.

    Args:
        checkpoint_filename: Path of a checkpoint, or of a directory, in
            which case the latest checkpoint found in it is restored.
        extra_vars: Iterable of saveables that are left out of the fallback
            restore attempt.
        trainable_only: If true, only saveables that are also in
            `tf.trainable_variables()` are restored.

    Raises:
        ValueError: If `checkpoint_filename` is a directory in which no
            checkpoint can be found. Errors raised by the fallback restore
            attempt propagate unchanged.
    """
    if tf.gfile.IsDirectory(checkpoint_filename):
        checkpoint_dir = checkpoint_filename
        checkpoint_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint_filename is None:
            # Fail up front: restoring from a None path raises a ValueError
            # that the handler below would misreport as a missing key and
            # then pointlessly retry.
            raise ValueError(
                'No checkpoint found in directory %s' % checkpoint_dir)
    logging.info('Loading checkpoint %s', checkpoint_filename)
    saveables = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
                 tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
    if trainable_only:
        saveables = list(set(saveables) & set(tf.trainable_variables()))
    # Try to restore all saveables, if that fails try without extra_vars.
    try:
        saver = tf.train.Saver(var_list=saveables)
        saver.restore(tf.get_default_session(), checkpoint_filename)
    except (ValueError, tf.errors.NotFoundError):
        logging.info('Missing key in checkpoint. Trying old checkpoint format.')
        saver = tf.train.Saver(var_list=list(set(saveables) - set(extra_vars)))
        saver.restore(tf.get_default_session(), checkpoint_filename)
示例2: evaluate_deterministic
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def evaluate_deterministic(
        model, make_data_iterator_fn, dataset_name, episodic,
        num_batches_to_discard=0, temperature=1.0, print_level=2,
        prediction_callback=None, extra_ops=()):
    """Make one pass over the data and return the average cross-entropy.

    The original docstring describes this as evaluation "with dropout turned
    off"; presumably that is what the literal 0 passed to `_make_feed`
    selects -- confirm against `_make_feed`.

    Args:
        model: Object exposing `xe_losses` and `last_state` fetches.
        make_data_iterator_fn: Callable returning an iterable of
            `(cond, cond_len, source, source_len, target)` tuples.
        dataset_name: Name used in the final log message only.
        episodic: Forwarded to `_make_feed`.
        num_batches_to_discard: Number of leading batches that are run but
            left out of the average.
        temperature: Forwarded to `_make_feed`.
        print_level: The final result is logged when this is at least 1.
        prediction_callback: Optional callable invoked with
            `(target, source_len, xe)` for every batch that is counted.
        extra_ops: Extra fetches run alongside each batch; their results are
            dropped.

    Returns:
        The summed masked loss divided by the summed length over the counted
        batches. NOTE(review): if no batch is counted the divisor is still
        the initial 0 -- confirm callers always supply enough batches.
    """
    total_xe = 0
    total_len = 0
    num_batches = 0
    last_state = None
    for batch in make_data_iterator_fn():
        cond, cond_len, source, source_len, target = batch
        feed = _make_feed(model, cond, cond_len, source, source_len, target,
                          last_state, episodic, 0, temperature)
        session = tf.get_default_session()
        fetches = [model.xe_losses, model.last_state] + list(extra_ops)
        # Only the first two results are needed; extra_ops outputs are ignored.
        xe, last_state = session.run(fetches, feed)[0:2]
        counted = num_batches >= num_batches_to_discard
        num_batches += 1
        if not counted:
            continue
        batch_xe, batch_len = _sum_masked(source_len, xe)
        total_xe += batch_xe
        total_len += batch_len
        if prediction_callback:
            prediction_callback(target, source_len, xe)
    average_xe = total_xe / total_len
    if print_level >= 1:
        logging.info('final %s xe: %6.5f (%s), batches: %s',
                     dataset_name, average_xe, total_len,
                     num_batches-num_batches_to_discard)
    return average_xe
示例3: save
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def save(self):
    """Run `self._save` in the default TensorFlow session."""
    session = tf.get_default_session()
    session.run(self._save)
示例4: restore
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def restore(self):
    """Run `self._restore` in the default TensorFlow session."""
    session = tf.get_default_session()
    session.run(self._restore)
示例5: take_sample
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def take_sample(self):
    """Run `self._take_sample` in the default TensorFlow session."""
    session = tf.get_default_session()
    session.run(self._take_sample)
示例6: switch_to_average
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def switch_to_average(self):
    """Run `self._switch` in the default TensorFlow session."""
    session = tf.get_default_session()
    session.run(self._switch)
示例7: reset
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def reset(self):
    """Run `self._reset` in the default TensorFlow session."""
    session = tf.get_default_session()
    session.run(self._reset)
示例8: store
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def store(self, obj):
    """Run `self._assign_op`, feeding `str(obj)` for `self._new_string`.

    Args:
        obj: Any object; its `str()` form is the value that gets fed.
    """
    session = tf.get_default_session()
    feed = {self._new_string: str(obj)}
    session.run(self._assign_op, feed_dict=feed)
示例9: global_step
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def global_step(self, session=None):
    """Return the current value of `self.global_step_var`.

    Args:
        session: Session used to evaluate the variable; the default session
            is used when this is None.
    """
    active = tf.get_default_session() if session is None else session
    return active.run(self.global_step_var)
示例10: fit
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def fit(self, feed, session=None):
    """Run `self.training_update` once and report the resulting values.

    Args:
        feed: Feed dict passed to `session.run`.
        session: Session to run in; the default session is used when this
            is None.

    Returns:
        A `(cost, summary, last_state)` tuple holding the evaluated
        `unregularized_loss`, `training_summary` and `last_state`.
    """
    active = tf.get_default_session() if session is None else session
    # Ask TensorFlow to report tensor allocations if the run hits an OOM.
    oom_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    fetches = [
        self.training_update,
        self.unregularized_loss,
        self.training_summary,
        self.last_state,
    ]
    results = active.run(fetches, feed_dict=feed, options=oom_options)
    # The first result belongs to the update fetch and is not returned.
    _, cost, summary, last_state = results
    return cost, summary, last_state
示例11: accumulate_gradients
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def accumulate_gradients(self, feed, session=None):
    """Run `self.accumulate_grads` once and report the resulting values.

    Args:
        feed: Feed dict passed to `session.run`.
        session: Session to run in; the default session is used when this
            is None.

    Returns:
        A `(cost, summary, last_state)` tuple holding the evaluated
        `unregularized_loss`, `training_summary` and `last_state`.
    """
    active = tf.get_default_session() if session is None else session
    fetches = [
        self.accumulate_grads,
        self.unregularized_loss,
        self.training_summary,
        self.last_state,
    ]
    results = active.run(fetches, feed_dict=feed)
    # The first result belongs to the accumulation fetch and is not returned.
    _, cost, summary, last_state = results
    return cost, summary, last_state
示例12: fit_accumulated
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def fit_accumulated(self, feed, session=None):
    """Run `self.accumulated_training_update` once; nothing is returned.

    Args:
        feed: Feed dict passed to `session.run`.
        session: Session to run in; the default session is used when this
            is None.
    """
    active = tf.get_default_session() if session is None else session
    fetches = [self.accumulated_training_update]
    active.run(fetches, feed_dict=feed)
示例13: reset_graph
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def reset_graph():
    """Close the default session, if there is one, then reset the default graph."""
    active_session = tf.get_default_session()
    if active_session:
        active_session.close()
    tf.reset_default_graph()
示例14: get_session
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def get_session(config=None):
    """Return the default session, creating one if none exists.

    Args:
        config: Forwarded to `make_session` when a new session has to be
            created; unused when a default session already exists.

    Returns:
        The existing default session, or the session returned by
        `make_session(config=config, make_default=True)`.
    """
    existing = tf.get_default_session()
    if existing is not None:
        return existing
    return make_session(config=config, make_default=True)
示例15: get_inception_probs
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_session [as 别名]
def get_inception_probs(inps):
    """Return softmax probabilities over the first 1000 logits per input.

    Inputs are processed in slices of `BATCH_SIZE` along the first axis and
    evaluated through the module-level `logits` / `inception_images` tensors
    in the default session.

    Args:
        inps: Array whose first axis indexes the examples. Values are
            rescaled with `x / 255 * 2 - 1` before being fed
            (assumes pixel values in [0, 255] -- TODO confirm with caller).

    Returns:
        A float32 array of shape `(inps.shape[0], 1000)` whose rows are the
        softmax of the corresponding logits.
    """
    session = tf.get_default_session()
    n_inputs = inps.shape[0]
    n_batches = int(np.ceil(float(n_inputs) / BATCH_SIZE))
    preds = np.zeros([n_inputs, 1000], dtype=np.float32)
    for i in range(n_batches):
        start = i * BATCH_SIZE
        inp = inps[start:start + BATCH_SIZE] / 255. * 2 - 1
        # The last slice may be shorter than BATCH_SIZE, so size the
        # destination by what was actually sliced.
        preds[start:start + inp.shape[0]] = session.run(
            logits, {inception_images: inp})[:, :1000]
    # Softmax is invariant to subtracting a per-row constant; removing the
    # row maximum keeps np.exp from overflowing to inf (and the ratio from
    # becoming NaN) on large logits. The exponentials are computed once.
    exps = np.exp(preds - np.max(preds, axis=1, keepdims=True))
    return exps / np.sum(exps, axis=1, keepdims=True)