当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.get_default_session函数代码示例

本文整理汇总了Python中tensorflow.get_default_session函数的典型用法代码示例。如果您正苦于以下问题：Python get_default_session函数的具体用法？Python get_default_session怎么用？Python get_default_session使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了get_default_session函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_session

def get_session():
    """Return the TensorFlow session the backend should use.

    Lookup order:
      1. the currently active default TensorFlow session, if any;
      2. the module-level Keras session ``_SESSION``, which is created on
         first use.

    A newly created session always enables soft device placement. If the
    ``OMP_NUM_THREADS`` environment variable is set to a non-empty value,
    it is also used as ``intra_op_parallelism_threads``.

    The global session can be replaced manually via `K.set_session(sess)`.
    """
    global _SESSION
    default_sess = tf.get_default_session()
    if default_sess is not None:
        return default_sess
    if _SESSION is not None:
        return _SESSION

    omp_threads = os.environ.get("OMP_NUM_THREADS")
    if omp_threads:
        config = tf.ConfigProto(intra_op_parallelism_threads=int(omp_threads), allow_soft_placement=True)
    else:
        config = tf.ConfigProto(allow_soft_placement=True)
    _SESSION = tf.Session(config=config)
    return _SESSION
开发者ID:faroit,项目名称:keras,代码行数:25,代码来源:tensorflow_backend.py

示例2: train

 def train(self, obs, actions, gaes, rewards, v_preds_next):
     """Run ``self.train_op`` once in the default TensorFlow session.

     The same observation batch is fed to both the current policy
     (``self.Policy.obs``) and the old policy (``self.Old_Policy.obs``).
     """
     feed = {
         self.Policy.obs: obs,
         self.Old_Policy.obs: obs,
         self.actions: actions,
         self.rewards: rewards,
         self.v_preds_next: v_preds_next,
         self.gaes: gaes,
     }
     session = tf.get_default_session()
     session.run(self.train_op, feed_dict=feed)
开发者ID:6-Billionaires,项目名称:gail_ppo_optimizer,代码行数:7,代码来源:ppo.py

示例3: train_step

    def train_step(self, cases, weights, caching):
        """Apply one parameter update computed from ``cases`` and ``weights``.

        There are three paths:
          * no cases: a warning is written and only ``self._increment_step``
            is run;
          * no batch limit, or at most ``self._max_batch_size`` cases: a
            single ``self.compute(self._take_step, ...)`` call;
          * more cases than the limit: gradients are computed slice by slice
            with ``self.compute(self._grad_tensors, ...)``, summed, fed to
            ``self._combined_grad_placeholders`` and applied with
            ``self._apply_gradients``; then ``self._increment_step`` is run.

        Args:
            cases: sequence of training cases.
            weights: per-case weights; must have the same length as ``cases``.
            caching: forwarded to ``self.compute``; must be falsy when the
                batch has to be split into slices.

        Raises:
            ValueError: if ``cases`` and ``weights`` differ in length.
        """
        if len(cases) != len(weights):
            raise ValueError('cases and weights must have the same length.')

        num_cases = len(cases)
        if num_cases == 0:
            # sys.stderr.write works on Python 2 and 3, unlike `print >>`.
            sys.stderr.write(" WARNING: Zero cases   \033[F\n")
            # still increment the step
            tf.get_default_session().run(self._increment_step)
            return

        sys.stderr.write(" Updating ({} cases)   \033[F\n".format(num_cases))
        if not self._max_batch_size or num_cases <= self._max_batch_size:
            self.compute(self._take_step, cases, weights, caching)
            return

        # Slicing the batch is not supported together with caching.
        assert not caching
        batch_size = self._max_batch_size
        grads = None
        starts = range(0, num_cases, batch_size)
        for start in verboserate(starts, desc='Computing gradients ({} cases)'.format(num_cases)):
            grads_slice = self.compute(self._grad_tensors,
                                       cases[start:start + batch_size],
                                       weights[start:start + batch_size],
                                       False)
            if grads is None:
                grads = grads_slice
            else:
                # Separate index name so the slice offset `start` is not shadowed.
                for k in range(len(self._grad_tensors)):
                    grads[k] += grads_slice[k]

        sess = tf.get_default_session()
        feed_dict = dict(zip(self._combined_grad_placeholders, grads))
        sess.run(self._apply_gradients, feed_dict)
        sess.run(self._increment_step)
开发者ID:siddk,项目名称:lang2program,代码行数:32,代码来源:parse_model.py

示例4: main

def main(args):
    """Load a metagraph and checkpoint from ``args.model_dir`` and freeze them.

    Variables are initialised, then restored from the checkpoint. The graph
    is frozen with ``freeze_graph_def`` using ``'embeddings'`` as the output
    node, and the result is written in serialized form to
    ``args.output_file``.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Locate the metagraph and checkpoint files
            print('Model directory: %s' % args.model_dir)
            model_dir_exp = os.path.expanduser(args.model_dir)
            meta_file, ckpt_file = facenet.get_model_filenames(model_dir_exp)

            print('Metagraph file: %s' % meta_file)
            print('Checkpoint file: %s' % ckpt_file)

            meta_path = os.path.join(model_dir_exp, meta_file)
            ckpt_path = os.path.join(model_dir_exp, ckpt_file)

            # `sess` is the default session inside this `with` block, so it is
            # used directly instead of re-querying tf.get_default_session().
            saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, ckpt_path)

            # Freeze the graph definition of the loaded model
            frozen_graph_def = freeze_graph_def(sess, sess.graph.as_graph_def(), 'embeddings')

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(args.output_file, 'wb') as out_file:
            out_file.write(frozen_graph_def.SerializeToString())
        print("%d ops in the final graph: %s" % (len(frozen_graph_def.node), args.output_file))
开发者ID:citysir,项目名称:facenet,代码行数:26,代码来源:freeze_graph.py

示例5: get_session

def get_session():
  """Get the globally defined TensorFlow session.

  If a default TensorFlow session is active, it becomes the global session;
  otherwise a new ``tf.InteractiveSession`` is created and stored. When
  Keras can be imported, its backend is pointed at the same session.

  Returns:
    _ED_SESSION: the active default session, or the newly created
    tf.InteractiveSession.
  """
  global _ED_SESSION
  default_session = tf.get_default_session()
  if default_session is None:
    _ED_SESSION = tf.InteractiveSession()
  else:
    _ED_SESSION = default_session

  import os
  save_stderr = sys.stderr
  try:
    sys.stderr = open(os.devnull, 'w')  # suppress keras import
    try:
      from keras import backend as K
    finally:
      # Restore stderr however the import exits. Without this, an exception
      # other than ImportError would propagate with stderr still redirected
      # to devnull, hiding its own traceback and all later output.
      # NOTE(review): the devnull handle is intentionally not closed here,
      # matching the previous behaviour, in case anything created during the
      # import kept a reference to the redirected stream.
      sys.stderr = save_stderr
    have_keras = True
  except ImportError:
    have_keras = False
  if have_keras:
    K.set_session(_ED_SESSION)

  return _ED_SESSION
开发者ID:JoyceYa,项目名称:edward,代码行数:29,代码来源:graphs.py

示例6: fit

 def fit(self, xs, ys):
     """Fit the regressor to inputs ``xs`` and targets ``ys``.

     If ``self.normalize_inputs`` is set, ``self.x_mean_var`` and
     ``self.x_std_var`` are first set to the per-column mean and
     (std + 1e-8) of ``xs``.

     Optimizer selection:
       * ``self.tr_optimizer`` with inputs ``[xs, ys, self.f_prob(xs)]``
         once ``self.first_optimized`` is True and ``self.use_trust_region``
         is set;
       * otherwise ``self.optimizer`` with inputs ``[xs, ys]``.

     The loss before and after optimisation, and their difference, are
     recorded with ``logger.record_tabular`` under the keys ``LossBefore``,
     ``LossAfter`` and ``dLoss``, each prefixed with ``<self.name>_`` when
     a name is set. ``self.first_optimized`` is then set to True.
     """
     if self.normalize_inputs:
         # recompute normalizing constants for inputs
         new_mean = np.mean(xs, axis=0, keepdims=True)
         new_std = np.std(xs, axis=0, keepdims=True) + 1e-8
         # Build the update ops once and feed new values through placeholders.
         # Creating fresh tf.assign/tf.group ops on every call would add nodes
         # to the graph each time fit() runs.
         if getattr(self, '_normalize_inputs_op', None) is None:
             self._x_mean_ph = tf.placeholder(
                 self.x_mean_var.dtype.base_dtype, shape=self.x_mean_var.get_shape())
             self._x_std_ph = tf.placeholder(
                 self.x_std_var.dtype.base_dtype, shape=self.x_std_var.get_shape())
             self._normalize_inputs_op = tf.group(
                 tf.assign(self.x_mean_var, self._x_mean_ph),
                 tf.assign(self.x_std_var, self._x_std_ph),
             )
         tf.get_default_session().run(
             self._normalize_inputs_op,
             feed_dict={self._x_mean_ph: new_mean, self._x_std_ph: new_std})
     if self.use_trust_region and self.first_optimized:
         old_prob = self.f_prob(xs)
         inputs = [xs, ys, old_prob]
         optimizer = self.tr_optimizer
     else:
         inputs = [xs, ys]
         optimizer = self.optimizer
     loss_before = optimizer.loss(inputs)
     prefix = self.name + "_" if self.name else ""
     logger.record_tabular(prefix + 'LossBefore', loss_before)
     optimizer.optimize(inputs)
     loss_after = optimizer.loss(inputs)
     logger.record_tabular(prefix + 'LossAfter', loss_after)
     logger.record_tabular(prefix + 'dLoss', loss_before - loss_after)
     self.first_optimized = True
开发者ID:flyers,项目名称:rllab,代码行数:27,代码来源:categorical_mlp_regressor.py

示例7: test_lookup_activations

 def test_lookup_activations(self):
     """Each looked-up activation must change the value of a constant -1 input."""
     inputs = tf.constant(-1.0, shape=[2, 2])
     with self.test_session():
         for name in ['relu', 'prelu', 'selu', 'crelu']:
             output = ops.lookup(name)(inputs)
             # Initialise any variables the activation may have created
             # before evaluating it.
             tf.get_default_session().run(tf.global_variables_initializer())
             self.assertNotEqual(inputs.eval()[0][0], output.eval()[0][0])
开发者ID:255BITS,项目名称:hyperchamber-gan,代码行数:10,代码来源:ops_test.py

示例8: restore_trainer

    def restore_trainer(self, filename):
        '''
        Restore the training progress (including the model) into the
        default TensorFlow session

        Args:
            filename: path the progress was saved to; ``self.modelsaver``
                restores from ``filename`` and ``self.saver`` restores from
                ``filename + '_trainvars'``
        '''
        session = tf.get_default_session()
        self.modelsaver.restore(session, filename)
        self.saver.restore(session, filename + '_trainvars')
开发者ID:vrenkens,项目名称:tfkaldi,代码行数:10,代码来源:trainer.py

示例9: test_logging_trainable

 def test_logging_trainable(self):
   """Run one training step, then check that LoggingTrainable('foo') logs a
   message containing the name of the variable 'foo'."""
   with tf.Graph().as_default() as g, self.test_session(g):
     var = tf.Variable(tf.constant(42.0), name='foo')
     var.initializer.run()
     cof = tf.constant(1.0)
     # Tensor operators instead of tf.sub / tf.mul: those functions were
     # removed in TensorFlow 1.0 (renamed tf.subtract / tf.multiply), while
     # the operator overloads work on every version.
     loss = var * cof - tf.constant(1.0)
     train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
     tf.get_default_session().run(train_step)
     self._run_monitor(learn.monitors.LoggingTrainable('foo'))
     self.assertRegexpMatches(str(self.logged_message), var.name)
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:10,代码来源:monitors_test.py

示例10: fit

    def fit(self, paths, policy=None, batch_size=32, max_itrs=100, logger=None, lr=1e-3,**kwargs):
        """Train the discriminator on policy samples against expert demonstrations.

        Both ``paths`` and ``self.expert_trajs`` are first annotated through
        ``self.eval_expert_probs(..., policy, insert=True)``. Each of the
        ``max_itrs`` iterations then samples ``batch_size`` policy and
        ``batch_size`` expert transitions, labels them 0 and 1 respectively,
        and runs ``self.step``. If ``logger`` is given, energy / value /
        ``d_tau`` statistics for the full policy and expert sets are recorded
        with ``logger.record_tabular``.

        Args:
            paths: policy rollouts.
            policy: forwarded to ``self.eval_expert_probs``.
            batch_size: number of samples drawn from *each* source per update.
            max_itrs: number of discriminator updates.
            logger: optional object providing ``record_tabular``.
            lr: value fed to ``self.lr``.
            **kwargs: accepted and ignored.

        Returns:
            The mean loss popped at the most recent ``TrainingIterator``
            heartbeat, or ``nan`` if no heartbeat fired (e.g. ``max_itrs == 0``).
        """
        self.eval_expert_probs(paths, policy, insert=True)
        self.eval_expert_probs(self.expert_trajs, policy, insert=True)
        keys = ('observations', 'actions', 'a_logprobs')
        obs, acts, path_probs = self.extract_paths(paths, keys=keys)
        expert_obs, expert_acts, expert_probs = self.extract_paths(self.expert_trajs, keys=keys)

        sess = tf.get_default_session()
        # Policy samples fill the first half of every batch (label 0) and
        # expert samples the second half (label 1); this never changes.
        labels = np.zeros((batch_size*2, 1))
        labels[batch_size:] = 1.0

        # Bound up front: mean_loss is otherwise only assigned on a heartbeat,
        # and the final return would raise UnboundLocalError without one.
        mean_loss = float('nan')

        # Train discriminator
        for it in TrainingIterator(max_itrs, heartbeat=5):
            obs_batch, act_batch, lprobs_batch = \
                self.sample_batch(obs, acts, path_probs, batch_size=batch_size)

            expert_obs_batch, expert_act_batch, expert_lprobs_batch = \
                self.sample_batch(expert_obs, expert_acts, expert_probs, batch_size=batch_size)

            obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
            act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
            lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)

            loss, _ = sess.run([self.loss, self.step], feed_dict={
                self.act_t: act_batch,
                self.obs_t: obs_batch,
                self.labels: labels,
                self.lprobs: lprobs_batch,
                self.lr: lr
            })

            it.record('loss', loss)
            if it.heartbeat:
                print(it.itr_message())
                mean_loss = it.pop_mean('loss')
                print('\tLoss:%f' % mean_loss)

        if logger:
            energy, logZ, dtau = self._run_energy_diagnostics(sess, obs, acts, path_probs)
            logger.record_tabular('IRLLogZ', np.mean(logZ))
            logger.record_tabular('IRLAverageEnergy', np.mean(energy))
            logger.record_tabular('IRLAverageLogPtau', np.mean(-energy-logZ))
            logger.record_tabular('IRLAverageLogQtau', np.mean(path_probs))
            logger.record_tabular('IRLMedianLogQtau', np.median(path_probs))
            logger.record_tabular('IRLAverageDtau', np.mean(dtau))

            energy, logZ, dtau = self._run_energy_diagnostics(sess, expert_obs, expert_acts, expert_probs)
            logger.record_tabular('IRLAverageExpertEnergy', np.mean(energy))
            logger.record_tabular('IRLAverageExpertLogPtau', np.mean(-energy-logZ))
            logger.record_tabular('IRLAverageExpertLogQtau', np.mean(expert_probs))
            logger.record_tabular('IRLMedianExpertLogQtau', np.median(expert_probs))
            logger.record_tabular('IRLAverageExpertDtau', np.mean(dtau))
        return mean_loss

    def _run_energy_diagnostics(self, sess, obs, acts, lprobs):
        """Evaluate ``[self.energy, self.value_fn, self.d_tau]`` on a full sample set.

        ``lprobs`` gets a trailing axis (``np.expand_dims(..., axis=1)``)
        before being fed to ``self.lprobs``.
        """
        return sess.run([self.energy, self.value_fn, self.d_tau],
                        feed_dict={self.act_t: acts, self.obs_t: obs,
                                   self.lprobs: np.expand_dims(lprobs, axis=1)})
开发者ID:saadmahboob,项目名称:inverse_rl,代码行数:54,代码来源:imitation_learning.py

示例11: get_session

def get_session():
    """Return the module-level session, creating it on first use.

    On the first call, the current default TensorFlow session is adopted if
    one is active; otherwise a new ``tf.Session`` is created. Every later
    call returns that same cached session.
    """
    global _session

    if _session is not None:
        return _session

    # First call: adopt the default session, or build a fresh one
    default_session = tf.get_default_session()
    _session = tf.Session() if default_session is None else default_session
    return _session
开发者ID:CloudBreadPaPa,项目名称:tensorrec,代码行数:11,代码来源:session_management.py

示例12: test_preserves_existing_session

  def test_preserves_existing_session(self):
    """A session the caller entered must still be the default after _square."""
    with tf.Session() as sess:
      sum_op = tf.reduce_sum([2, 2])
      self.assertIs(sess, tf.get_default_session())

      self.assertEqual(123 * 123, self._square(123))

      self.assertIs(sess, tf.get_default_session())
      total = sess.run(sum_op)
      self.assertEqual(total, 4)
开发者ID:jlewi,项目名称:tensorboard,代码行数:11,代码来源:util_test.py

示例13: zero_model_gradient_accumulators

    def zero_model_gradient_accumulators(cls) -> None:
        """Run every ``<scope>/zero_model_gradient_accumulators`` op in one call.

        For each of the variable scopes listed below, the op is looked up by
        name in the default graph. All of them are then executed together in
        the default session.
        """
        scope_names = (
            'empty_statistic',
            'move_rate',
            'game_state_as_update',
            'updated_statistic',
            'updated_update',
            'cost_function',
        )
        graph = tf.get_default_graph()
        zero_operations = [
            graph.get_operation_by_name(
                '{}/zero_model_gradient_accumulators'.format(scope))
            for scope in scope_names]

        tf.get_default_session().run(zero_operations)
开发者ID:thomasste,项目名称:ugtsa,代码行数:14,代码来源:model_builder.py

示例14: predict_with_three_models_on_hashtags

def _predict_with_model(model_dir, use_emb_model, use_char_model, hashtag_dir, hashtag_names):
    """Run one HumorPredictor over every hashtag after resetting Keras.

    Returns:
        (output_probs, first_tweet_ids, second_tweet_ids): three lists with
        one entry per hashtag, in ``hashtag_names`` order.
    """
    # Clear Keras state and hand it the current default TensorFlow session
    # before building the model.
    K.clear_session()
    K.set_session(tf.get_default_session())
    predictor = humor_predictor.HumorPredictor(model_dir, use_emb_model=use_emb_model, use_char_model=use_char_model)
    output_probs = []
    first_ids = []
    second_ids = []
    for hashtag_name in hashtag_names:
        _, np_output_prob, _, first_tweet_ids, second_tweet_ids = predictor(hashtag_dir, hashtag_name)
        output_probs.append(np_output_prob)
        first_ids.append(first_tweet_ids)
        second_ids.append(second_tweet_ids)
    return output_probs, first_ids, second_ids


def predict_with_three_models_on_hashtags(hashtag_dir, hashtag_emb_dir, trial_hashtag_names, labels_exist=True):
    """Predict with the emb+char, emb-only and char-only humor models.

    Each model is run over every hashtag in ``trial_hashtag_names`` (read
    from ``hashtag_dir``), with a Keras reset before each model.

    Returns:
        all_predictions: one array per hashtag whose three columns are the
            emb+char, emb-only and char-only model outputs (each flattened
            into a column).
        hashtag_labels: per-hashtag labels loaded from ``hashtag_emb_dir``
            when ``labels_exist`` is true, else None.
        per_hashtag_first_tweet_ids, per_hashtag_second_tweet_ids: tweet id
            lists reported by the emb+char model.
    """
    # eval_hashtag_names = get_hashtag_file_names(SEMEVAL_HUMOR_EVAL_DIR)
    emb_char_predictions, per_hashtag_first_tweet_ids, per_hashtag_second_tweet_ids = _predict_with_model(
        EMB_CHAR_HUMOR_MODEL_DIR, True, True, hashtag_dir, trial_hashtag_names)
    emb_predictions, _, _ = _predict_with_model(
        EMB_HUMOR_MODEL_DIR, True, False, hashtag_dir, trial_hashtag_names)
    char_predictions, _, _ = _predict_with_model(
        CHAR_HUMOR_MODEL_DIR, False, True, hashtag_dir, trial_hashtag_names)

    all_predictions = [
        np.concatenate([np.reshape(emb_char, [-1, 1]), np.reshape(emb, [-1, 1]), np.reshape(char, [-1, 1])], axis=1)
        for emb_char, emb, char in zip(emb_char_predictions, emb_predictions, char_predictions)]

    hashtag_labels = None
    if labels_exist:
        hashtag_labels = []
        for hashtag_name in trial_hashtag_names:
            # Parenthesised single-argument print behaves the same on Python 2 and 3.
            print('Loading label for hashtag %s' % hashtag_name)
            _, _, np_labels, _, _, _ = load_hashtag_data(hashtag_emb_dir, hashtag_name)
            hashtag_labels.append(np_labels)

    return all_predictions, hashtag_labels, per_hashtag_first_tweet_ids, per_hashtag_second_tweet_ids
开发者ID:text-machine-lab,项目名称:ht_wars,代码行数:50,代码来源:humor_ensemble_processing2.py

示例15: main

def _fix_batch_norm_nodes(graph_def):
    """Rewrite ref-based ops in ``graph_def`` in place so it can be frozen.

    RefSwitch becomes Switch, and any input containing ``moving_`` is
    redirected to its ``/read`` tensor. AssignSub becomes Sub and AssignAdd
    becomes Add, with the ``use_locking`` attribute removed from them.
    """
    assign_to_arith = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Index-based loop because inputs are rewritten in place.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op in assign_to_arith:
            node.op = assign_to_arith[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def _whitelisted_node_names(graph_def):
    """Print and return names of nodes under the InceptionResnetV1, embeddings or phase_train prefixes."""
    names = []
    for node in graph_def.node:
        if node.name.startswith(('InceptionResnetV1', 'embeddings', 'phase_train')):
            print(node.name)
            names.append(node.name)
    return names


def main(args):
    """Freeze the model in ``args.model_dir`` and write it to ``args.output_file``.

    The metagraph is imported and variables are initialised, then restored
    from the checkpoint. Ref-based nodes are rewritten, and the variables
    whitelisted by ``_whitelisted_node_names`` are converted to constants,
    with ``'embeddings'`` as the output node. The resulting GraphDef is
    written in serialized form.
    """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Load the model metagraph and checkpoint
            print('Model directory: %s' % args.model_dir)
            model_dir_exp = os.path.expanduser(args.model_dir)
            meta_file, ckpt_file = facenet.get_model_filenames(model_dir_exp)

            print('Metagraph file: %s' % meta_file)
            print('Checkpoint file: %s' % ckpt_file)

            saver = tf.train.import_meta_graph(os.path.join(model_dir_exp, meta_file), clear_devices=True)
            # `sess` is the default session inside this `with` block.
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, os.path.join(model_dir_exp, ckpt_file))

            # Retrieve the protobuf graph definition and fix the batch norm nodes
            gd = sess.graph.as_graph_def()
            _fix_batch_norm_nodes(gd)

            # Get the list of important nodes
            output_node_names = 'embeddings'
            whitelist_names = _whitelisted_node_names(gd)

            # Replace all the variables in the graph with constants of the same values
            output_graph_def = graph_util.convert_variables_to_constants(
                sess, gd, output_node_names.split(","),
                variable_names_whitelist=whitelist_names)

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(args.output_file, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
开发者ID:billtiger,项目名称:CATANA,代码行数:48,代码来源:freeze_graph.py


注:本文中的tensorflow.get_default_session函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。