當前位置: 首頁>>代碼示例>>Python>>正文


Python backend.set_session方法代碼示例

本文整理匯總了Python中tensorflow.keras.backend.set_session方法的典型用法代碼示例。如果您正苦於以下問題:Python backend.set_session方法的具體用法?Python backend.set_session怎麽用?Python backend.set_session使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorflow.keras.backend的用法示例。


在下文中一共展示了backend.set_session方法的9個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        with graph.as_default():
            if sess is not None:
                set_session(sess)
            inp = None
            output = None
            if self.shared_network is None:
                inp = Input((self.input_dim,))
                output = self.get_network_head(inp).output
            else:
                inp = self.shared_network.input
                output = self.shared_network.output
            output = Dense(
                self.output_dim, activation=self.activation, 
                kernel_initializer='random_normal')(output)
            self.model = Model(inp, output)
            self.model.compile(
                optimizer=SGD(lr=self.lr), loss=self.loss) 
開發者ID:quantylab,項目名稱:rltrader,代碼行數:21,代碼來源:networks.py

示例2: cpu_config

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def cpu_config(first=False):
    # intel optimizations
    num_cores, num_sockets = get_cpuinfo()
    if first:
        print("system info::")
        print("Number of physical cores:: ", num_cores)
        print("Number of sockets::", num_sockets)
    backend.set_session(
        tf.Session(
            config=tf.ConfigProto(
                intra_op_parallelism_threads=num_cores,
                inter_op_parallelism_threads=num_sockets,
            )
        )
    )
###########################################################
# Training
########################################################### 
開發者ID:intel,項目名稱:stacks-usecase,代碼行數:20,代碼來源:main.py

示例3: set_session

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def set_session(sess): pass 
開發者ID:quantylab,項目名稱:rltrader,代碼行數:3,代碼來源:networks.py

示例4: predict

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def predict(self, sample):
        with self.lock:
            with graph.as_default():
                if sess is not None:
                    set_session(sess)
                return self.model.predict(sample).flatten() 
開發者ID:quantylab,項目名稱:rltrader,代碼行數:8,代碼來源:networks.py

示例5: train_on_batch

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def train_on_batch(self, x, y):
        loss = 0.
        with self.lock:
            with graph.as_default():
                if sess is not None:
                    set_session(sess)
                loss = self.model.train_on_batch(x, y)
        return loss 
開發者ID:quantylab,項目名稱:rltrader,代碼行數:10,代碼來源:networks.py

示例6: get_shared_network

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def get_shared_network(cls, net='dnn', num_steps=1, input_dim=0):
        with graph.as_default():
            if sess is not None:
                set_session(sess)
            if net == 'dnn':
                return DNN.get_network_head(Input((input_dim,)))
            elif net == 'lstm':
                return LSTMNetwork.get_network_head(
                    Input((num_steps, input_dim)))
            elif net == 'cnn':
                return CNN.get_network_head(
                    Input((1, num_steps, input_dim))) 
開發者ID:quantylab,項目名稱:rltrader,代碼行數:14,代碼來源:networks.py

示例7: _run

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def _run(FLAGS):
  hparams = init_hparams(FLAGS)
  init_random_seeds(hparams)

  for run in range(hparams.copies):
    log_start_of_run(FLAGS, hparams, run)

    with tf.Session() as sess:
      K.set_session(sess)
      agent, checkpoint = init_agent(sess, hparams)

      restored = checkpoint.restore()
      if not restored:
        sess.run(tf.global_variables_initializer())

      if not hparams.test_only:
        log_graph()

        agent.clone_weights()

        if hparams.num_workers == 1:
          train(0, agent, hparams, checkpoint)
        else:
          workers = [
              threading.Thread(
                  target=train, args=(worker_id, agent, hparams, checkpoint))
              for worker_id in range(hparams.num_workers)
          ]

          for worker in workers:
            worker.start()

          for worker in workers:
            worker.join()
      else:
        test(hparams, agent)

    hparams = init_hparams(FLAGS) 
開發者ID:for-ai,項目名稱:rl,代碼行數:40,代碼來源:train.py

示例8: configure_gpus

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def configure_gpus(gpus):
    # set gpu id and tf settings
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(g) for g in gpus])

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    K.set_session(tf.Session(config=config))


# loads a saved experiment using the saved parameters.
# runs all initialization steps so that we can use the models right away 
開發者ID:xamyzhao,項目名稱:brainstorm,代碼行數:14,代碼來源:experiment_engine.py

示例9: run

# 需要導入模塊: from tensorflow.keras import backend [as 別名]
# 或者: from tensorflow.keras.backend import set_session [as 別名]
def run(project_dir, gpu_mon, logger, args):
    """
    Runs training of a model in a mpunet project directory.

    Args:
        project_dir: A path to a mpunet project
        gpu_mon: An initialized GPUMonitor object
        logger: A mpunet logging object
        args: argparse arguments
    """
    # Read in hyperparameters from YAML file
    from mpunet.hyperparameters import YAMLHParams
    hparams = YAMLHParams(project_dir + "/train_hparams.yaml", logger=logger)
    validate_hparams(hparams)

    # Wait for PID to terminate before continuing?
    if args.wait_for:
        from mpunet.utils import await_PIDs
        await_PIDs(args.wait_for)

    # Prepare sequence generators and potential model specific hparam changes
    train, val, hparams = get_data_sequences(project_dir=project_dir,
                                             hparams=hparams,
                                             logger=logger,
                                             args=args)

    # Set GPU visibility and create model with MirroredStrategy
    set_gpu(gpu_mon, args)
    import tensorflow as tf
    with tf.distribute.MirroredStrategy().scope():
        model = get_model(project_dir=project_dir, train_seq=train,
                          hparams=hparams, logger=logger, args=args)

        # Get trainer and compile model
        from mpunet.train import Trainer
        trainer = Trainer(model, logger=logger)
        trainer.compile_model(n_classes=hparams["build"].get("n_classes"),
                              reduction=tf.keras.losses.Reduction.NONE,
                              **hparams["fit"])

    # Debug mode?
    if args.debug:
        from tensorflow.python import debug as tfdbg
        from tensorflow.keras import backend as K
        K.set_session(tfdbg.LocalCLIDebugWrapperSession(K.get_session()))

    # Fit the model
    _ = trainer.fit(train=train, val=val,
                    train_im_per_epoch=args.train_images_per_epoch,
                    val_im_per_epoch=args.val_images_per_epoch,
                    hparams=hparams, no_im=args.no_images, **hparams["fit"])
    save_final_weights(model, project_dir, logger) 
開發者ID:perslev,項目名稱:MultiPlanarUNet,代碼行數:54,代碼來源:train.py


注:本文中的tensorflow.keras.backend.set_session方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。