This article collects typical code examples of the tensorflow.get_default_session method in Python. If you have been wondering what exactly tensorflow.get_default_session does, how to call it, or what real usage looks like, the curated examples below may help. You can also explore further usage examples from the tensorflow module, where this method is defined.
The following presents 15 code examples of tensorflow.get_default_session, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
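Before the examples, a minimal sketch of the method itself may help (TensorFlow 1.x API): tf.get_default_session() returns the innermost session pushed onto the thread-local default-session stack, usually by a `with` block, and returns None when no session is active.

import tensorflow as tf

assert tf.get_default_session() is None  # no session on the stack yet

with tf.Session() as sess:
    # inside the block, sess is the default session
    assert tf.get_default_session() is sess
    print(tf.get_default_session().run(tf.constant(42)))  # 42

assert tf.get_default_session() is None  # popped again on exit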
Example 1: assign
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
# This snippet also relies on warnings, parse_scopes, __load_dict__, and
# init from its source module.
def assign(scopes):
    if not isinstance(scopes, list):
        scopes = [scopes]
    for scope in scopes:
        model_name = parse_scopes(scope)[0]
        try:
            # Prefer a loader registered for this exact model name.
            __load_dict__[model_name](scope)
        except KeyError:
            try:
                # Fall back to the scope's own pretrained-weights op,
                # executed in the default session.
                tf.get_default_session().run(scope.pretrained())
            except Exception:
                # Last resort: find a loader whose key is a substring of
                # the model name.
                found = False
                for (key, fun) in __load_dict__.items():
                    if key in model_name.lower():
                        found = True
                        fun(scope)
                        break
                if not found:
                    warnings.warn('Random initialization will be performed '
                                  'because the pre-trained weights for ' +
                                  model_name + ' are not found.')
                    init(scope)
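A hedged usage sketch: `assign` comes from a tensornets-style model zoo, where a model's output tensor carries a `pretrained()` op. The `tensornets` import and `nets.ResNet50` below are assumptions about that surrounding library, not part of the snippet itself.

import tensorflow as tf
import tensornets as nets  # assumed host library of `assign`

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = nets.ResNet50(inputs)

with tf.Session() as sess:
    assign(model)  # loads pretrained weights through the default session
    # the model is now ready for sess.run(model, feed_dict=...)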
Example 2: load_model
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
# This snippet also relies on os, gfile (tensorflow.python.platform), and
# get_model_filenames from its source module.
def load_model(model, input_map=None):
    # Check if the model is a model directory (containing a metagraph and a
    # checkpoint file) or if it is a protobuf file with a frozen graph.
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        with gfile.FastGFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, input_map=input_map, name='')
    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)
        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)
        saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file),
                                           input_map=input_map)
        saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
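A hedged usage sketch: the checkpoint branch restores into the default session, so `load_model` is typically called inside one. The directory path and tensor names below are illustrative assumptions, not part of the function.

import numpy as np

with tf.Graph().as_default():
    with tf.Session() as sess:
        load_model('path/to/model_dir')  # directory or frozen .pb file
        graph = tf.get_default_graph()
        images_ph = graph.get_tensor_by_name('input:0')        # assumed name
        embeddings = graph.get_tensor_by_name('embeddings:0')  # assumed name
        images = np.zeros((1, 160, 160, 3), dtype=np.float32)  # placeholder batch
        emb = sess.run(embeddings, feed_dict={images_ph: images})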
Example 3: get_reward
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
# This snippet also relies on numpy as np.
def get_reward(self, obs, actions):
    """
    Predict the reward using the observation and action

    :param obs: (tf.Tensor or np.ndarray) the observation
    :param actions: (tf.Tensor or np.ndarray) the action
    :return: (np.ndarray) the reward
    """
    sess = tf.get_default_session()
    if len(obs.shape) == 1:
        obs = np.expand_dims(obs, 0)
    if len(actions.shape) == 1:
        actions = np.expand_dims(actions, 0)
    elif len(actions.shape) == 0:
        # one discrete action
        actions = np.expand_dims(actions, 0)
    feed_dict = {self.generator_obs_ph: obs, self.generator_acs_ph: actions}
    reward = sess.run(self.reward_op, feed_dict)
    return reward
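A hedged usage sketch of the batching behavior above: a single observation (shape (obs_dim,)) and a scalar discrete action are both promoted to batches of one, so single transitions and full batches share the same code path. The `discriminator` object below is assumed.

import numpy as np

with tf.Session().as_default():
    # single transition: shapes (4,) and () are expanded to (1, 4) and (1,)
    r = discriminator.get_reward(np.zeros(4), np.int64(1))
    # a batch of transitions passes through unchanged
    r = discriminator.get_reward(np.zeros((32, 4)),
                                 np.zeros(32, dtype=np.int64))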
Example 4: initialize
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def initialize(sess=None):
    """
    Initialize all the uninitialized variables in the global scope.

    :param sess: (TensorFlow Session)
    """
    if sess is None:
        sess = tf.get_default_session()
    # ALREADY_INITIALIZED is a module-level set recording variables handled
    # by earlier calls, so repeated calls never re-initialize a variable.
    new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
    sess.run(tf.variables_initializer(new_variables))
    ALREADY_INITIALIZED.update(new_variables)

# ================================================================
# Theano-like Function
# ================================================================
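A hedged usage sketch for the `initialize` helper above: because it only touches variables absent from `ALREADY_INITIALIZED`, it can be called after each batch of new variables without resetting older ones.

ALREADY_INITIALIZED = set()  # module-level state assumed by `initialize`

with tf.Session() as sess:
    v1 = tf.Variable(2.0)
    initialize()   # initializes v1
    v2 = tf.Variable(5.0)
    initialize()   # initializes only v2; v1 keeps its current value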
Example 5: __call__
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def __call__(self, *args, sess=None, **kwargs):
    assert len(args) <= len(self.inputs), "Too many arguments provided"
    if sess is None:
        sess = tf.get_default_session()
    feed_dict = {}
    # Update the feed dict with the positional args.
    for inpt, value in zip(self.inputs, args):
        self._feed_input(feed_dict, inpt, value)
    # Update the feed dict with givens, unless already fed explicitly.
    for inpt in self.givens:
        feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt])
    # The last entry of outputs_update is the update op; drop its result.
    results = sess.run(self.outputs_update, feed_dict=feed_dict, **kwargs)[:-1]
    return results

# ================================================================
# Flat vectors
# ================================================================
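A hedged usage sketch: this `__call__` is the core of the Theano-like `function` wrapper in baselines-style tf_util modules. The `function` factory below, assumed to wrap this class, feeds positional args into the placeholders and falls back to givens for the rest.

x = tf.placeholder(tf.int32, (), name='x')
y = tf.placeholder(tf.int32, (), name='y')
z = 3 * x + 2 * y
lin = function([x, y], z, givens={y: 0})  # assumed factory around this class

with tf.Session():
    assert lin(2) == 6        # y falls back to its given value 0
    assert lin(2, 2) == 10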
Example 6: run
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def run(*args, **kwargs):
    # Run the specified ops in the default session.
    return tf.get_default_session().run(*args, **kwargs)
Example 7: init_tf
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
# This snippet also relies on numpy as np and a create_session helper.
def init_tf(config_dict=dict()):
    # Only create and install a session if none is active yet.
    if tf.get_default_session() is None:
        tf.set_random_seed(np.random.randint(1 << 31))
        create_session(config_dict, force_as_default=True)

#----------------------------------------------------------------------------
# Create tf.Session based on config dict of the form
# {'gpu_options.allow_growth': True}
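`create_session` itself is not shown above. A plausible sketch matching the trailing comment, modeled on StyleGAN-style tfutil code: dotted keys are resolved onto tf.ConfigProto fields, and `force_as_default` keeps the session installed as the default outside any `with` block. Treat the details as assumptions.

def create_session(config_dict=dict(), force_as_default=False):
    config = tf.ConfigProto()
    for key, value in config_dict.items():
        # Resolve a dotted key like 'gpu_options.allow_growth' on the proto.
        fields = key.split('.')
        obj = config
        for field in fields[:-1]:
            obj = getattr(obj, field)
        setattr(obj, fields[-1], value)
    session = tf.Session(config=config)
    if force_as_default:
        # Enter the default-session context manager and never exit it.
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__()
    return session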
Example 8: __getstate__
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def __getstate__(self):
    sess = tf.get_default_session()
    if sess is None:
        raise RuntimeError("PicklableVariable requires a default "
                           "TensorFlow session")
    # Pickle the variable's current value rather than the TF object itself.
    return {'var': sess.run(self.var)}
Example 9: __setstate__
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def __setstate__(self, d):
    # Rebuild the variable from the pickled value and initialize it.
    self.var = tf.Variable(d['var'])
    sess = tf.get_default_session()
    if sess is None:
        raise RuntimeError("PicklableVariable requires a default "
                           "TensorFlow session")
    sess.run(self.var.initializer)
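A hedged sketch tying Examples 8 and 9 together: round-tripping a `PicklableVariable` with the standard pickle module, under the assumption (as in cleverhans-style code) that its constructor forwards to tf.Variable.

import pickle

with tf.Session().as_default():
    v = PicklableVariable(3.0)  # assumed to wrap tf.Variable(3.0)
    tf.get_default_session().run(v.var.initializer)
    data = pickle.dumps(v)      # __getstate__ stores the value 3.0
    v2 = pickle.loads(data)     # __setstate__ rebuilds and initializes v2.var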
Example 10: train
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def train(hps, datasets):
    """Train the LFADS model.

    Args:
      hps: The dictionary of hyperparameters.
      datasets: A dictionary of data dictionaries. The dataset dict is simply
        a name (string) -> data dictionary mapping (see top of lfads.py).
    """
    model = build_model(hps, kind="train", datasets=datasets)
    if hps.do_reset_learning_rate:
        # Reset the learning rate to its initial value in the default session.
        sess = tf.get_default_session()
        sess.run(model.learning_rate.initializer)
    model.train_model(datasets)
Example 11: build
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def build(self, env, policy, q_function, q_function2, value_function,
          target_value_function):
    self._create_placeholders(env)

    policy_loss = self._policy_loss_for(policy, q_function, q_function2,
                                        value_function)
    value_function_loss = self._value_function_loss_for(
        policy, q_function, q_function2, value_function)
    q_function_loss = self._q_function_loss_for(q_function,
                                                target_value_function)
    if q_function2 is not None:
        q_function2_loss = self._q_function_loss_for(q_function2,
                                                     target_value_function)

    optimizer = tf.train.AdamOptimizer(
        self._learning_rate, name='optimizer')
    policy_training_op = optimizer.minimize(
        loss=policy_loss, var_list=policy.trainable_variables)
    value_training_op = optimizer.minimize(
        loss=value_function_loss,
        var_list=value_function.trainable_variables)
    q_function_training_op = optimizer.minimize(
        loss=q_function_loss, var_list=q_function.trainable_variables)
    if q_function2 is not None:
        q_function2_training_op = optimizer.minimize(
            loss=q_function2_loss, var_list=q_function2.trainable_variables)

    self._training_ops = [
        policy_training_op, value_training_op, q_function_training_op
    ]
    if q_function2 is not None:
        self._training_ops += [q_function2_training_op]
    self._target_update_ops = self._create_target_update(
        source=value_function, target=target_value_function)

    # Initialize all variables (including optimizer slots) in the default
    # session once the full graph is built.
    tf.get_default_session().run(tf.global_variables_initializer())
Example 12: train
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
# This snippet also relies on the time module.
def train(self, sampler, n_epochs=1000):
    """Return a generator that performs RL training.

    Args:
        sampler: Sampler used to collect experience and to draw random
            minibatches for the gradient updates.
        n_epochs (int): Number of training epochs; the epoch index is
            yielded after each epoch completes.
    """
    self._start = time.time()
    for epoch in range(n_epochs):
        for t in range(self._epoch_length):
            sampler.sample()
            batch = sampler.random_batch(self._batch_size)
            feed_dict = {
                self._observations_ph: batch['observations'],
                self._actions_ph: batch['actions'],
                self._next_observations_ph: batch['next_observations'],
                self._rewards_ph: batch['rewards'],
                self._terminals_ph: batch['terminals'],
            }
            # One gradient step, then a target-network update, both in the
            # default session.
            tf.get_default_session().run(self._training_ops, feed_dict)
            tf.get_default_session().run(self._target_update_ops)
        yield epoch
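A hedged usage sketch of the generator protocol above: the caller drives training epoch by epoch and can interleave evaluation or checkpointing between iterations. The `algo` and `sampler` objects are assumed.

with tf.Session().as_default():
    for epoch in algo.train(sampler, n_epochs=100):
        if epoch % 10 == 0:
            print('finished epoch %d' % epoch)  # evaluate/checkpoint here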
Example 13: get_reward
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def get_reward(self, obs, acs):
    # Same pattern as Example 3: promote 1-D (unbatched) observations and
    # actions to a batch of one before querying the reward op.
    sess = tf.get_default_session()
    if len(obs.shape) == 1:
        obs = np.expand_dims(obs, 0)
    if len(acs.shape) == 1:
        acs = np.expand_dims(acs, 0)
    feed_dict = {self.generator_obs_ph: obs, self.generator_acs_ph: acs}
    reward = sess.run(self.reward_op, feed_dict)
    return reward
Example 14: load_state
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
def load_state(fname):
    # Restore all saveable variables into the default session.
    saver = tf.train.Saver()
    saver.restore(tf.get_default_session(), fname)
Example 15: save_state
# Required import: import tensorflow [as alias]
# Or: from tensorflow import get_default_session [as alias]
# This snippet also relies on the os module.
def save_state(fname):
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    saver = tf.train.Saver()
    saver.save(tf.get_default_session(), fname)

# ================================================================
# Placeholders
# ================================================================
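A hedged sketch tying Examples 14 and 15 together: save a checkpoint from one graph/session and restore it in a fresh one. The checkpoint path is illustrative.

import tensorflow as tf

with tf.Graph().as_default(), tf.Session().as_default():
    v = tf.Variable(1.0, name='v')
    tf.get_default_session().run(tf.global_variables_initializer())
    save_state('/tmp/model/ckpt')   # illustrative path

with tf.Graph().as_default(), tf.Session().as_default():
    v = tf.Variable(0.0, name='v')
    load_state('/tmp/model/ckpt')
    print(tf.get_default_session().run(v))  # prints 1.0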