当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.get_default_session方法代码示例

本文整理汇总了Python中tensorflow.get_default_session方法的典型用法代码示例。如果您想了解tensorflow.get_default_session方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。注意：该方法属于TensorFlow 1.x API（在TensorFlow 2.x中需通过tf.compat.v1.get_default_session调用），当前线程没有默认会话时它返回None。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


在下文中一共展示了tensorflow.get_default_session方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: assign

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def assign(scopes):
    """Load pre-trained weights into one or more model scopes.

    For every scope, the weights are loaded by the first strategy that applies:

    1. the loader registered in ``__load_dict__`` under the scope's exact
       model name (the first element returned by ``parse_scopes``);
    2. running the ops returned by ``scope.pretrained()`` in the default
       TensorFlow session;
    3. the first ``__load_dict__`` loader whose key is a substring of the
       lower-cased model name.

    If none applies, a warning is issued and ``init(scope)`` is called
    (the warning text states that random initialization will be performed).

    Args:
        scopes: a single scope or a list of scopes.
    """
    if not isinstance(scopes, list):
        scopes = [scopes]
    for scope in scopes:
        model_name = parse_scopes(scope)[0]

        # Check membership before calling. Catching KeyError around the call
        # would also swallow a KeyError raised *inside* the loader and
        # silently fall back to a different strategy.
        if model_name in __load_dict__:
            __load_dict__[model_name](scope)
            continue

        try:
            tf.get_default_session().run(scope.pretrained())
        except Exception:
            # Deliberate best-effort: there may be no default session, no
            # pretrained() method, or its ops may fail. In that case, try a
            # fuzzy name match instead. This is narrower than a bare
            # ``except:``, so KeyboardInterrupt / SystemExit still propagate.
            pass
        else:
            continue

        lowered_name = model_name.lower()
        for key, fun in __load_dict__.items():
            if key in lowered_name:
                fun(scope)
                break
        else:
            warnings.warn('Random initialization will be performed '
                          'because the pre-trained weights for ' +
                          model_name + ' are not found.')
            init(scope)
开发者ID:taehoonlee,项目名称:tensornets,代码行数:24,代码来源:pretrained.py

示例2: load_model

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def load_model(model, input_map=None):
    """Load a model into the current default graph.

    ``model`` is either a path to a protobuf file holding a frozen GraphDef,
    or a directory holding a metagraph and a checkpoint. In the directory
    case, the checkpoint is restored into the current default session.

    Args:
        model: path to a frozen-graph file or a model directory; ``~`` is
            expanded.
        input_map: optional mapping forwarded to ``tf.import_graph_def`` or
            ``tf.train.import_meta_graph``.
    """
    model_path = os.path.expanduser(model)

    if not os.path.isfile(model_path):
        # Directory layout: metagraph + checkpoint files.
        print('Model directory: %s' % model_path)
        meta_name, ckpt_name = get_model_filenames(model_path)

        print('Metagraph file: %s' % meta_name)
        print('Checkpoint file: %s' % ckpt_name)

        meta_path = os.path.join(model_path, meta_name)
        ckpt_path = os.path.join(model_path, ckpt_name)
        restorer = tf.train.import_meta_graph(meta_path, input_map=input_map)
        restorer.restore(tf.get_default_session(), ckpt_path)
        return

    # Single protobuf file containing a serialized (frozen) GraphDef.
    print('Model filename: %s' % model_path)
    with gfile.FastGFile(model_path, 'rb') as handle:
        serialized = handle.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, input_map=input_map, name='')
开发者ID:GaoangW,项目名称:TNT,代码行数:21,代码来源:facenet.py

示例3: get_reward

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def get_reward(self, obs, actions):
        """
        Evaluate ``self.reward_op`` for the given observation/action pair(s).

        Unbatched inputs are promoted to a batch of one before being fed to
        the generator placeholders in the default session.

        :param obs: (tf.Tensor or np.ndarray) the observation
        :param actions: (tf.Tensor or np.ndarray) the action
        :return: (np.ndarray) the reward
        """
        session = tf.get_default_session()

        # A 1-D observation is a single sample: add a leading batch axis.
        if len(obs.shape) == 1:
            obs = np.expand_dims(obs, 0)

        # A 1-D action (a single unbatched action vector) and a 0-D action
        # (one discrete action) both get a leading batch axis.
        if len(actions.shape) in (0, 1):
            actions = np.expand_dims(actions, 0)

        feeds = {
            self.generator_obs_ph: obs,
            self.generator_acs_ph: actions,
        }
        return session.run(self.reward_op, feeds)
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:22,代码来源:adversary.py

示例4: initialize

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def initialize(sess=None):
    """
    Run the initializers of global variables not yet recorded as initialized.

    Variables whose initializers are run here are added to the module-level
    ``ALREADY_INITIALIZED`` set, so they are skipped on later calls.

    :param sess: (TensorFlow Session) session in which to run the
        initializers; when None, the current default session is used
    """
    session = tf.get_default_session() if sess is None else sess
    pending = set(tf.global_variables()) - ALREADY_INITIALIZED
    init_op = tf.variables_initializer(pending)
    session.run(init_op)
    ALREADY_INITIALIZED.update(pending)


# ================================================================
# Theano-like Function
# ================================================================ 
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:18,代码来源:tf_util.py

示例5: __call__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def __call__(self, *args, sess=None, **kwargs):
        """
        Evaluate the wrapped outputs for the given positional inputs.

        Positional ``args`` are bound, in order, to ``self.inputs`` via
        ``self._feed_input``. Each entry of ``self.givens`` is used for its
        placeholder only when no explicit value was supplied. Extra
        ``kwargs`` are forwarded to ``sess.run``.

        :param sess: (TensorFlow Session) session to run in; when None, the
            current default session is used
        :return: (list) the results of ``self.outputs_update`` without its
            final element (presumably the grouped update op)
        """
        assert len(args) <= len(self.inputs), "Too many arguments provided"
        session = tf.get_default_session() if sess is None else sess

        feed = {}
        for placeholder, value in zip(self.inputs, args):
            self._feed_input(feed, placeholder, value)

        # Givens only fill placeholders that were not passed explicitly.
        for placeholder in self.givens:
            if placeholder not in feed:
                feed[placeholder] = self.givens[placeholder]

        fetched = session.run(self.outputs_update, feed_dict=feed, **kwargs)
        return fetched[:-1]


# ================================================================
# Flat vectors
# ================================================================ 
开发者ID:Stable-Baselines-Team,项目名称:stable-baselines,代码行数:20,代码来源:tf_util.py

示例6: run

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def run(*args, **kwargs):
    """Run the specified ops in the current default TensorFlow session.

    All positional and keyword arguments are forwarded unchanged to the
    session's ``run`` method, and its result is returned.
    """
    session = tf.get_default_session()
    return session.run(*args, **kwargs)
开发者ID:zalandoresearch,项目名称:disentangling_conditional_gans,代码行数:4,代码来源:tfutil.py

示例7: init_tf

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def init_tf(config_dict=None):
    """Create a TensorFlow session unless a default session already exists.

    When there is no default session, the graph-level random seed is drawn
    from NumPy's global RNG. A session is then created via
    ``create_session(config_dict, force_as_default=True)``.

    Args:
        config_dict: optional dict of session config settings forwarded to
            ``create_session``, e.g. ``{'gpu_options.allow_growth': True}``.
            ``None`` (the default) means an empty dict.
    """
    # Build a fresh dict on every call instead of sharing one mutable default
    # object across all calls.
    if config_dict is None:
        config_dict = {}
    if tf.get_default_session() is None:
        tf.set_random_seed(np.random.randint(1 << 31))
        create_session(config_dict, force_as_default=True)

#----------------------------------------------------------------------------
# Create tf.Session based on config dict of the form
# {'gpu_options.allow_growth': True} 
开发者ID:zalandoresearch,项目名称:disentangling_conditional_gans,代码行数:10,代码来源:tfutil.py

示例8: __getstate__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def __getstate__(self):
        """Pickle support: snapshot the current value of ``self.var``.

        :return: (dict) ``{'var': value}``, where ``value`` is obtained by
            running ``self.var`` in the default session
        :raises RuntimeError: if no default TensorFlow session is active
        """
        session = tf.get_default_session()
        if session is not None:
            current_value = session.run(self.var)
            return {'var': current_value}
        raise RuntimeError("PicklableVariable requires a default "
                           "TensorFlow session")
开发者ID:StephanZheng,项目名称:neural-fingerprinting,代码行数:8,代码来源:serial.py

示例9: __setstate__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def __setstate__(self, d):
        """Unpickle support: rebuild ``self.var`` from a saved value.

        A new ``tf.Variable`` is created from ``d['var']`` (the key written
        by ``__getstate__``), and its initializer is run in the default
        session.

        :param d: (dict) the pickled state
        :raises RuntimeError: if no default TensorFlow session is active;
            the variable has already been created at that point
        """
        saved_value = d['var']
        self.var = tf.Variable(saved_value)
        session = tf.get_default_session()
        if session is not None:
            session.run(self.var.initializer)
            return
        raise RuntimeError("PicklableVariable requires a default "
                           "TensorFlow session")
开发者ID:StephanZheng,项目名称:neural-fingerprinting,代码行数:9,代码来源:serial.py

示例10: train

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def train(hps, datasets):
  """Build the LFADS training model and train it.

  Args:
    hps: The dictionary of hyperparameters. When
      ``hps.do_reset_learning_rate`` is truthy, the learning-rate variable's
      initializer is re-run in the default session before training.
    datasets: A dictionary of data dictionaries.  The dataset dict is simply a
      name(string)-> data dictionary mapping (See top of lfads.py).
  """
  model = build_model(hps, kind="train", datasets=datasets)

  if hps.do_reset_learning_rate:
    # Re-running the initializer restores the learning rate's initial value
    # (assumes model.learning_rate is a tf.Variable -- confirm in lfads.py).
    session = tf.get_default_session()
    lr_initializer = model.learning_rate.initializer
    session.run(lr_initializer)

  model.train_model(datasets)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:16,代码来源:run_lfads.py

示例11: build

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def build(self, env, policy, q_function, q_function2, value_function,
              target_value_function):
        """Construct the SAC training graph and initialize global variables.

        Creates the input placeholders for ``env``, builds the policy,
        value-function and Q-function losses (plus a second Q-function loss
        when ``q_function2`` is given), creates one ``minimize`` op per loss
        from a single shared Adam optimizer, and builds the target-network
        update ops. Finally runs ``tf.global_variables_initializer()`` in the
        default session.

        NOTE(review): that initializer re-initializes *every* global variable
        in the graph, not only the ones created here. Confirm that no
        pretrained or restored weights are loaded before this call.

        Args:
            env: environment passed to ``self._create_placeholders``.
            policy: policy network; its ``trainable_variables`` are optimized
                on the policy loss.
            q_function: Q-function network optimized on its Q loss.
            q_function2: optional second Q-function; pass ``None`` to skip
                its loss and training op.
            value_function: value network optimized on the value loss and
                used as the source of the target update.
            target_value_function: target value network used in the Q losses
                and updated from ``value_function``.
        """

        self._create_placeholders(env)

        policy_loss = self._policy_loss_for(policy, q_function, q_function2, value_function)
        value_function_loss = self._value_function_loss_for(
            policy, q_function, q_function2, value_function)
        q_function_loss = self._q_function_loss_for(q_function,
                                                    target_value_function)
        if q_function2 is not None:
            q_function2_loss = self._q_function_loss_for(q_function2,
                                                         target_value_function)

        # One optimizer instance is shared by all training ops. Each minimize()
        # call restricts its update to the given network's variables.
        optimizer = tf.train.AdamOptimizer(
            self._learning_rate, name='optimizer')
        policy_training_op = optimizer.minimize(
            loss=policy_loss, var_list=policy.trainable_variables)
        value_training_op = optimizer.minimize(
            loss=value_function_loss,
            var_list=value_function.trainable_variables)
        q_function_training_op = optimizer.minimize(
            loss=q_function_loss, var_list=q_function.trainable_variables)
        if q_function2 is not None:
            q_function2_training_op = optimizer.minimize(
                loss=q_function2_loss, var_list=q_function2.trainable_variables)

        # Training ops collected for the training loop; the second Q op is
        # appended only when a second Q-function exists.
        self._training_ops = [
            policy_training_op, value_training_op, q_function_training_op
        ]
        if q_function2 is not None:
            self._training_ops += [q_function2_training_op]
        self._target_update_ops = self._create_target_update(
            source=value_function, target=target_value_function)

        # Initializes all global variables, including the optimizer's slots.
        tf.get_default_session().run(tf.global_variables_initializer())
开发者ID:xuwd11,项目名称:cs294-112_hws,代码行数:38,代码来源:sac.py

示例12: train

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def train(self, sampler, n_epochs=1000):
        """Return a generator that performs RL training.

        Records the start time in ``self._start`` when iteration begins. Each
        epoch runs ``self._epoch_length`` training steps and then yields the
        epoch index. A training step:

        1. calls ``sampler.sample()`` to collect new experience;
        2. draws ``sampler.random_batch(self._batch_size)``;
        3. runs ``self._training_ops`` on that batch in the default session;
        4. runs ``self._target_update_ops`` in a separate ``run`` call, so the
           target update happens after the training step.

        Args:
            sampler: object providing ``sample()`` and ``random_batch(size)``.
                The batch must be a mapping with ``'observations'``,
                ``'actions'``, ``'next_observations'``, ``'rewards'`` and
                ``'terminals'`` entries.
            n_epochs (int): number of epochs to run.

        Yields:
            int: the index of the epoch that has just finished.
        """
        self._start = time.time()
        for epoch in range(n_epochs):
            for _ in range(self._epoch_length):
                sampler.sample()

                batch = sampler.random_batch(self._batch_size)
                feed_dict = {
                    self._observations_ph: batch['observations'],
                    self._actions_ph: batch['actions'],
                    self._next_observations_ph: batch['next_observations'],
                    self._rewards_ph: batch['rewards'],
                    self._terminals_ph: batch['terminals'],
                }
                # Look the session up once per step. It is not hoisted out of
                # the epoch loop because the caller may change the default
                # session between resumptions of this generator.
                sess = tf.get_default_session()
                sess.run(self._training_ops, feed_dict)
                sess.run(self._target_update_ops)

            yield epoch
开发者ID:xuwd11,项目名称:cs294-112_hws,代码行数:29,代码来源:sac.py

示例13: get_reward

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def get_reward(self, obs, acs):
        """Evaluate ``self.reward_op`` for the given observations and actions.

        A 1-D ``obs`` or ``acs`` is treated as a single sample and given a
        leading batch axis before being fed to the generator placeholders in
        the default session.
        """
        def _batched(arr):
            # Promote a single 1-D sample to a batch of one.
            return np.expand_dims(arr, 0) if len(arr.shape) == 1 else arr

        session = tf.get_default_session()
        obs_batch = _batched(obs)
        acs_batch = _batched(acs)
        feeds = {
            self.generator_obs_ph: obs_batch,
            self.generator_acs_ph: acs_batch,
        }
        return session.run(self.reward_op, feeds)
开发者ID:Hwhitetooth,项目名称:lirpg,代码行数:11,代码来源:adversary.py

示例14: load_state

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def load_state(fname):
    """Restore variables from checkpoint ``fname`` into the default session.

    Uses a ``tf.train.Saver`` constructed with its default variable list.
    """
    checkpoint_saver = tf.train.Saver()
    default_sess = tf.get_default_session()
    checkpoint_saver.restore(default_sess, fname)
开发者ID:Hwhitetooth,项目名称:lirpg,代码行数:5,代码来源:utils.py

示例15: save_state

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_session [as 别名]
def save_state(fname):
    """Save the default session's variables to checkpoint path ``fname``.

    The parent directory of ``fname`` is created if it does not exist. A bare
    file name with no directory component is saved as-is (relative to the
    current working directory).
    """
    parent_dir = os.path.dirname(fname)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when fname actually has a directory component.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    saver = tf.train.Saver()
    saver.save(tf.get_default_session(), fname)

# ================================================================
# Placeholders
# ================================================================ 
开发者ID:Hwhitetooth,项目名称:lirpg,代码行数:10,代码来源:utils.py


注:本文中的tensorflow.get_default_session方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。