当前位置: 首页>>代码示例>>Python>>正文


Python backend.set_session方法代码示例

本文整理汇总了Python中tensorflow.keras.backend.set_session方法的典型用法代码示例。如果您正苦于以下问题:Python backend.set_session方法的具体用法?Python backend.set_session怎么用?Python backend.set_session使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.keras.backend的用法示例。


在下文中一共展示了backend.set_session方法的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        with graph.as_default():
            if sess is not None:
                set_session(sess)
            inp = None
            output = None
            if self.shared_network is None:
                inp = Input((self.input_dim,))
                output = self.get_network_head(inp).output
            else:
                inp = self.shared_network.input
                output = self.shared_network.output
            output = Dense(
                self.output_dim, activation=self.activation, 
                kernel_initializer='random_normal')(output)
            self.model = Model(inp, output)
            self.model.compile(
                optimizer=SGD(lr=self.lr), loss=self.loss) 
开发者ID:quantylab,项目名称:rltrader,代码行数:21,代码来源:networks.py

示例2: cpu_config

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def cpu_config(first=False):
    # intel optimizations
    num_cores, num_sockets = get_cpuinfo()
    if first:
        print("system info::")
        print("Number of physical cores:: ", num_cores)
        print("Number of sockets::", num_sockets)
    backend.set_session(
        tf.Session(
            config=tf.ConfigProto(
                intra_op_parallelism_threads=num_cores,
                inter_op_parallelism_threads=num_sockets,
            )
        )
    )
###########################################################
# Training
########################################################### 
开发者ID:intel,项目名称:stacks-usecase,代码行数:20,代码来源:main.py

示例3: set_session

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def set_session(sess): pass 
开发者ID:quantylab,项目名称:rltrader,代码行数:3,代码来源:networks.py

示例4: predict

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def predict(self, sample):
        with self.lock:
            with graph.as_default():
                if sess is not None:
                    set_session(sess)
                return self.model.predict(sample).flatten() 
开发者ID:quantylab,项目名称:rltrader,代码行数:8,代码来源:networks.py

示例5: train_on_batch

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def train_on_batch(self, x, y):
        loss = 0.
        with self.lock:
            with graph.as_default():
                if sess is not None:
                    set_session(sess)
                loss = self.model.train_on_batch(x, y)
        return loss 
开发者ID:quantylab,项目名称:rltrader,代码行数:10,代码来源:networks.py

示例6: get_shared_network

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def get_shared_network(cls, net='dnn', num_steps=1, input_dim=0):
        with graph.as_default():
            if sess is not None:
                set_session(sess)
            if net == 'dnn':
                return DNN.get_network_head(Input((input_dim,)))
            elif net == 'lstm':
                return LSTMNetwork.get_network_head(
                    Input((num_steps, input_dim)))
            elif net == 'cnn':
                return CNN.get_network_head(
                    Input((1, num_steps, input_dim))) 
开发者ID:quantylab,项目名称:rltrader,代码行数:14,代码来源:networks.py

示例7: _run

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def _run(FLAGS):
  hparams = init_hparams(FLAGS)
  init_random_seeds(hparams)

  for run in range(hparams.copies):
    log_start_of_run(FLAGS, hparams, run)

    with tf.Session() as sess:
      K.set_session(sess)
      agent, checkpoint = init_agent(sess, hparams)

      restored = checkpoint.restore()
      if not restored:
        sess.run(tf.global_variables_initializer())

      if not hparams.test_only:
        log_graph()

        agent.clone_weights()

        if hparams.num_workers == 1:
          train(0, agent, hparams, checkpoint)
        else:
          workers = [
              threading.Thread(
                  target=train, args=(worker_id, agent, hparams, checkpoint))
              for worker_id in range(hparams.num_workers)
          ]

          for worker in workers:
            worker.start()

          for worker in workers:
            worker.join()
      else:
        test(hparams, agent)

    hparams = init_hparams(FLAGS) 
开发者ID:for-ai,项目名称:rl,代码行数:40,代码来源:train.py

示例8: configure_gpus

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def configure_gpus(gpus):
    # set gpu id and tf settings
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(g) for g in gpus])

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    K.set_session(tf.Session(config=config))


# loads a saved experiment using the saved parameters.
# runs all initialization steps so that we can use the models right away 
开发者ID:xamyzhao,项目名称:brainstorm,代码行数:14,代码来源:experiment_engine.py

示例9: run

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import set_session [as 别名]
def run(project_dir, gpu_mon, logger, args):
    """
    Runs training of a model in a mpunet project directory.

    Args:
        project_dir: A path to a mpunet project
        gpu_mon: An initialized GPUMonitor object
        logger: A mpunet logging object
        args: argparse arguments
    """
    # Read in hyperparameters from YAML file
    from mpunet.hyperparameters import YAMLHParams
    hparams = YAMLHParams(project_dir + "/train_hparams.yaml", logger=logger)
    validate_hparams(hparams)

    # Wait for PID to terminate before continuing?
    if args.wait_for:
        from mpunet.utils import await_PIDs
        await_PIDs(args.wait_for)

    # Prepare sequence generators and potential model specific hparam changes
    train, val, hparams = get_data_sequences(project_dir=project_dir,
                                             hparams=hparams,
                                             logger=logger,
                                             args=args)

    # Set GPU visibility and create model with MirroredStrategy
    set_gpu(gpu_mon, args)
    import tensorflow as tf
    with tf.distribute.MirroredStrategy().scope():
        model = get_model(project_dir=project_dir, train_seq=train,
                          hparams=hparams, logger=logger, args=args)

        # Get trainer and compile model
        from mpunet.train import Trainer
        trainer = Trainer(model, logger=logger)
        trainer.compile_model(n_classes=hparams["build"].get("n_classes"),
                              reduction=tf.keras.losses.Reduction.NONE,
                              **hparams["fit"])

    # Debug mode?
    if args.debug:
        from tensorflow.python import debug as tfdbg
        from tensorflow.keras import backend as K
        K.set_session(tfdbg.LocalCLIDebugWrapperSession(K.get_session()))

    # Fit the model
    _ = trainer.fit(train=train, val=val,
                    train_im_per_epoch=args.train_images_per_epoch,
                    val_im_per_epoch=args.val_images_per_epoch,
                    hparams=hparams, no_im=args.no_images, **hparams["fit"])
    save_final_weights(model, project_dir, logger) 
开发者ID:perslev,项目名称:MultiPlanarUNet,代码行数:54,代码来源:train.py


注:本文中的tensorflow.keras.backend.set_session方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。