当前位置: 首页>>代码示例>>Python>>正文


Python v1.ConfigProto方法代码示例

本文整理汇总了 Python 中 tensorflow.compat.v1.ConfigProto(会话配置类)的典型用法代码示例。如果您正苦于以下问题:v1.ConfigProto 的具体用法是什么?v1.ConfigProto 怎么用?有哪些 v1.ConfigProto 的使用例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解其所在模块 tensorflow.compat.v1 的用法示例。


在下文中一共展示了 v1.ConfigProto 的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: testTrainWithSessionConfig

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def testTrainWithSessionConfig(self):
    """Train a logistic classifier while passing an explicit session config.

    The graph is built inside a fresh default graph, training runs for 300
    steps with a ``ConfigProto`` that enables soft placement, and the loss
    returned by ``learning.train`` is checked to be set and below 0.015.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = tf.constant(self._inputs, dtype=tf.float32)
      labels = tf.constant(self._labels, dtype=tf.float32)

      predictions = LogisticClassifier(inputs)
      # The return value is discarded; the total loss is fetched afterwards.
      loss_ops.log_loss(labels, predictions)

      train_op = learning.create_train_op(
          loss_ops.get_total_loss(),
          gradient_descent.GradientDescentOptimizer(learning_rate=1.0))

      final_loss = learning.train(
          train_op,
          None,
          number_of_steps=300,
          log_every_n_steps=10,
          session_config=tf.ConfigProto(allow_soft_placement=True))
    self.assertIsNotNone(final_loss)
    self.assertLess(final_loss, .015)
开发者ID:google-research,项目名称:tf-slim,代码行数:25,代码来源:learning_test.py

示例2: __init__

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def __init__(self, net_name, do_tf_logging=True, flagCS = False):
        """Configure TensorFlow/Keras and build the L2-Net models.

        Args:
            net_name: Forwarded to ``build_L2_net``.
            do_tf_logging: Forwarded to ``set_tf_logging``.
            flagCS: Stored unchanged as ``self.flagCS``.
        """
        set_tf_logging(do_tf_logging)

        # Let the GPU allocation grow on demand, see
        # https://kobkrit.com/using-allow-growth-memory-option-in-tensorflow-and-keras-dc8c8081bc96
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True
        # Make the new TensorFlow session the default session for Keras.
        set_session(tf.Session(config=session_config))

        net, net_cen, mean, mean_cen = build_L2_net(net_name)
        self.flagCS = flagCS
        self.model = net
        self.model_cen = net_cen
        self.pix_mean = mean
        self.pix_mean_cen = mean_cen
开发者ID:luigifreda,项目名称:pyslam,代码行数:18,代码来源:L2_Net.py

示例3: __init__

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def __init__(self, k: int, lu: float = 0.01, lv: float = 0.01, a: float = 1, b: float = 0.01) -> None:
        """Store the hyper-parameters and initialise all other state to None.

        Args:
            k: Stored as ``self.k``.
            lu: Stored as ``self.lu``.
            lv: Stored as ``self.lv``.
            a: Stored as ``self.a``.
            b: Stored as ``self.b``.
        """
        self.__sn = 'wmf'
        self.k, self.lu, self.lv = k, lu, lv
        self.a, self.b = a, b

        # TensorFlow config with on-demand GPU memory growth.
        self.tf_config = tf.ConfigProto()
        self.tf_config.gpu_options.allow_growth = True

        # Remaining attributes start out unset.
        for attr in ('uids', 'n_users', 'usm',
                     'iids', 'n_items', 'ism',
                     'n_ratings', 'u_rated', 'i_rated',
                     'fue', 'fie'):
            setattr(self, attr, None)
开发者ID:domainxz,项目名称:top-k-rec,代码行数:22,代码来源:wmf.py

示例4: set_gpu_fraction

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def set_gpu_fraction(sess=None, gpu_fraction=0.3):
    """Create a session limited to a fraction of the GPU memory.

    Parameters
    ----------
    sess : a session instance of TensorFlow
        Not read by this function; a new session is always created and
        returned.
    gpu_fraction : a float
        Fraction of GPU memory, (0 ~ 1]

    Returns
    -------
    The newly created ``tf.Session``.

    References
    ----------
    - `TensorFlow using GPU <https://www.tensorflow.org/versions/r0.9/how_tos/using_gpu/index.html>`_
    """
    print("  tensorlayer: GPU MEM Fraction %f" % gpu_fraction)
    memory_cap = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
    session_config = tf.ConfigProto(gpu_options=memory_cap)
    return tf.Session(config=session_config)
开发者ID:ravisvi,项目名称:super-resolution-videos,代码行数:20,代码来源:ops.py

示例5: __init__

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def __init__(self, player_config, env_config):
    """Create the session, build the policy and load its variables.

    Args:
      player_config: Mapping read for the keys 'index', 'stacked' (default
        True), 'policy' (default 'cnn') and 'checkpoint'.
      env_config: Mapping read for the optional key 'action_set'; 'default'
        is used when the key is absent.
    """
    player_base.PlayerBase.__init__(self, player_config)

    if 'action_set' in env_config:
      self._action_set = env_config['action_set']
    else:
      self._action_set = 'default'

    # Session whose GPU memory allocation grows on demand.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    self._sess = tf.Session(config=session_config)

    self._player_prefix = 'player_{}'.format(player_config['index'])
    if player_config.get('stacked', True):
      num_stacked = 4
    else:
      num_stacked = 1
    policy_name = player_config.get('policy', 'cnn')
    self._stacker = ObservationStacker(num_stacked)

    # Variables are created under '<player prefix>/ppo2_model'.
    with tf.variable_scope(self._player_prefix), \
        tf.variable_scope('ppo2_model'):
      make_policy = build_policy(
          DummyEnv(self._action_set, num_stacked), policy_name)
      self._policy = make_policy(nbatch=1, sess=self._sess)
    _load_variables(player_config['checkpoint'], self._sess,
                    prefix=self._player_prefix + '/')
开发者ID:google-research,项目名称:football,代码行数:20,代码来源:ppo2_cnn.py

示例6: initialize_session

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def initialize_session(self):
    """Create ``self.sess`` and run the variable initializers.

    When ``ENABLE_TF_OPTIMIZATIONS`` is truthy a default ``tf.Session`` is
    used. Otherwise the session is configured with model pruning disabled,
    the graph rewrite passes listed below set to OFF, and memory optimization
    set to NO_MEM_OPT.
    """
    if not ENABLE_TF_OPTIMIZATIONS:
      config = tf.ConfigProto()
      rewriter = config.graph_options.rewrite_options
      rewriter.disable_model_pruning = True
      for pass_name in ('constant_folding',
                        'arithmetic_optimization',
                        'remapping',
                        'shape_optimization',
                        'dependency_optimization',
                        'function_optimization',
                        'layout_optimizer',
                        'loop_optimization'):
        setattr(rewriter, pass_name, rewriter.OFF)
      rewriter.memory_optimization = rewriter.NO_MEM_OPT
      self.sess = tf.Session(config=config)
    else:
      self.sess = tf.Session()

    # Initialize global variables first, then local ones.
    self.sess.run(tf.global_variables_initializer())
    self.sess.run(tf.local_variables_initializer())
开发者ID:google-research,项目名称:meta-dataset,代码行数:24,代码来源:trainer.py

示例7: create_session_config

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def create_session_config(log_device_placement=False,
                          enable_graph_rewriter=False,
                          gpu_mem_fraction=0.95,
                          use_tpu=False,
                          xla_jit_level=tf.OptimizerOptions.OFF,
                          inter_op_parallelism_threads=0,
                          intra_op_parallelism_threads=0):
  """Build the ``tf.ConfigProto`` to use for TensorFlow sessions.

  Args:
    log_device_placement: Forwarded to ``ConfigProto``.
    enable_graph_rewriter: When true (and not on TPU), use graph options
      whose rewriter config has the layout optimizer switched ON.
    gpu_mem_fraction: Value for ``per_process_gpu_memory_fraction``.
    use_tpu: When true, default ``GraphOptions`` are used and the two
      options above that affect graph options are ignored.
    xla_jit_level: ``global_jit_level`` of the optimizer options; only used
      when neither ``use_tpu`` nor ``enable_graph_rewriter`` is set.
    inter_op_parallelism_threads: Forwarded to ``ConfigProto``.
    intra_op_parallelism_threads: Forwarded to ``ConfigProto``.

  Returns:
    The assembled ``tf.ConfigProto`` (soft placement and isolated session
    state are always enabled).
  """
  if use_tpu:
    graph_options = tf.GraphOptions()
  elif enable_graph_rewriter:
    rewriter = rewriter_config_pb2.RewriterConfig()
    rewriter.layout_optimizer = rewriter_config_pb2.RewriterConfig.ON
    graph_options = tf.GraphOptions(rewrite_options=rewriter)
  else:
    optimizer_options = tf.OptimizerOptions(
        opt_level=tf.OptimizerOptions.L1,
        do_function_inlining=False,
        global_jit_level=xla_jit_level)
    graph_options = tf.GraphOptions(optimizer_options=optimizer_options)

  memory_cap = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_fraction)

  return tf.ConfigProto(
      allow_soft_placement=True,
      graph_options=graph_options,
      gpu_options=memory_cap,
      log_device_placement=log_device_placement,
      inter_op_parallelism_threads=inter_op_parallelism_threads,
      intra_op_parallelism_threads=intra_op_parallelism_threads,
      isolate_session_state=True)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:35,代码来源:trainer_lib.py

示例8: encode

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def encode(wav_data, checkpoint_path, sample_length=64000):
  """Compute encodings for a batch of audio with a pretrained model.

  Args:
    wav_data: Numpy array [batch_size, sample_length]; a 1-D array is
      treated as a batch of one.
    checkpoint_path: Location of the pretrained model.
    sample_length: The total length of the final wave file, padded with 0s.
  Returns:
    encoding: a [mb, 125, 16] encoding (for 64000 sample audio file).
  """
  if wav_data.ndim == 1:
    wav_data = np.expand_dims(wav_data, 0)
  batch_size = wav_data.shape[0]

  # Soft placement plus on-demand GPU memory growth for the session.
  config = tf.ConfigProto(allow_soft_placement=True)
  config.gpu_options.allow_growth = True

  # Build the model in its own graph, restore the checkpoint, then encode.
  with tf.Graph().as_default(), tf.Session(config=config) as sess:
    hop = Config().ae_hop_length
    wav_data, sample_length = utils.trim_for_encoding(
        wav_data, sample_length, hop)
    net = load_nsynth(batch_size=batch_size, sample_length=sample_length)
    tf.train.Saver().restore(sess, checkpoint_path)
    encodings = sess.run(net["encoding"], feed_dict={net["X"]: wav_data})
  return encodings
开发者ID:magenta,项目名称:magenta,代码行数:28,代码来源:fastgen.py

示例9: generate_session_config

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def generate_session_config() -> tf.ConfigProto:
    """
    Build the ConfigProto used for ML-Agents: GPU memory is allocated on
    demand instead of all at once, and soft placement is enabled for the
    multi-GPU case.
    """
    # Soft placement lets an operation be placed on an alternative device
    # automatically, which avoids exceptions during multi-GPU training when
    # the requested device doesn't support the operation or doesn't exist.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True
    return session_config
开发者ID:StepNeverStop,项目名称:RLs,代码行数:15,代码来源:tf.py

示例10: build_config

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def build_config(limit_gpu_fraction=0.2, limit_cpu_fraction=10):
    """Build a ``ConfigProto`` that caps GPU memory and CPU thread usage.

    Besides returning the config, this function sets the environment
    variables ``CUDA_VISIBLE_DEVICES`` and, when ``limit_cpu_fraction`` is
    not None, ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS``.

    Args:
        limit_gpu_fraction: If greater than 0, only GPU "0" is made visible
            and the per-process GPU memory fraction is capped at this value
            (with ``allow_growth`` enabled). Otherwise all GPUs are hidden.
        limit_cpu_fraction: ``None`` leaves the thread settings untouched.
            Otherwise the thread count is derived as follows:
            ``0`` gives 1 thread; a negative value ``v`` gives
            ``cpu_count + v + 1`` threads (e.g. -2 gives all CPUs except 1);
            a value in (0, 1) gives that fraction of the available CPUs;
            a value >= 1 gives ``int(value)`` threads. The result is never
            less than 1 for the first three cases.

    Returns:
        The configured ``ConfigProto``.
    """
    if limit_gpu_fraction > 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        gpu_options = GPUOptions(
            allow_growth=True,
            per_process_gpu_memory_fraction=limit_gpu_fraction)
        config = ConfigProto(gpu_options=gpu_options)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        config = ConfigProto(device_count={'GPU': 0})

    if limit_cpu_fraction is not None:
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to a single CPU instead of failing on the arithmetic.
        available_cpus = os.cpu_count() or 1
        if limit_cpu_fraction == 0:
            cpu_count = 1
        elif limit_cpu_fraction < 0:
            # -2 gives all CPUs except 1
            cpu_count = max(1, int(available_cpus + limit_cpu_fraction + 1))
        elif limit_cpu_fraction < 1:
            # 0.5 gives 50% of available CPUs
            cpu_count = max(1, int(available_cpus * limit_cpu_fraction))
        else:
            # 2 gives 2 CPUs
            cpu_count = int(limit_cpu_fraction)
        config.inter_op_parallelism_threads = cpu_count
        config.intra_op_parallelism_threads = cpu_count
        os.environ['OMP_NUM_THREADS'] = str(1)
        os.environ['MKL_NUM_THREADS'] = str(cpu_count)
    return config
开发者ID:scottgigante,项目名称:m-phate,代码行数:31,代码来源:train.py

示例11: get_session_config

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def get_session_config(self):
    """Return the session ``tf.ConfigProto`` for the Estimator model.

    This body consists of the docstring only, so a call yields ``None``.
    NOTE(review): presumably concrete models override this - confirm.
    """
开发者ID:google-research,项目名称:tensor2robot,代码行数:4,代码来源:model_interface.py

示例12: start_tf_sess

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def start_tf_sess():
    """
    Create and return a tf.Session whose config enables on-demand GPU
    memory growth.
    """
    session_config = tf.ConfigProto()
    gpu_settings = session_config.gpu_options
    gpu_settings.allow_growth = True
    return tf.Session(config=session_config)
开发者ID:re-search,项目名称:gpt2-estimator,代码行数:9,代码来源:gpt_2.py

示例13: make_session

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def make_session(config=None, num_cpu=None, make_default=False, graph=None):
    """Create a TensorFlow session, by default limited to <num_cpu> CPUs.

    If `num_cpu` is None it is read from the RCALL_NUM_CPU environment
    variable, falling back to the machine's CPU count. The thread limits are
    only applied when no `config` is supplied; a caller-provided `config` is
    used as is. With `make_default` set an InteractiveSession is returned,
    otherwise a plain Session.
    """
    if num_cpu is None:
        num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))

    if config is None:
        config = tf.ConfigProto(
            allow_soft_placement=True,
            inter_op_parallelism_threads=num_cpu,
            intra_op_parallelism_threads=num_cpu)
        config.gpu_options.allow_growth = True

    session_factory = tf.InteractiveSession if make_default else tf.Session
    return session_factory(config=config, graph=graph)
开发者ID:microsoft,项目名称:nni,代码行数:17,代码来源:util.py

示例14: _build_session

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def _build_session(self):
    """Create a tf.Session, with global JIT level ON_2 if `self.use_xla`."""
    config = tf.ConfigProto()
    if self.use_xla:
      optimizer_options = config.graph_options.optimizer_options
      optimizer_options.global_jit_level = tf.OptimizerOptions.ON_2
    return tf.Session(config=config)
开发者ID:PINTO0309,项目名称:PINTO_model_zoo,代码行数:8,代码来源:inference.py

示例15: _lazily_initialize

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import ConfigProto [as 别名]
def _lazily_initialize(self):
        """Create the graph and session on first use; later calls do nothing."""
        # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
        import tensorflow.compat.v1 as tf

        with self._initialization_lock:
            if not self._session:
                new_graph = tf.Graph()
                with new_graph.as_default():
                    self.initialize_graph()
                # Don't reserve GPU because libpng can't run on GPU.
                cpu_only = tf.ConfigProto(device_count={"GPU": 0})
                self._session = tf.Session(graph=new_graph, config=cpu_only)
开发者ID:tensorflow,项目名称:tensorboard,代码行数:16,代码来源:op_evaluator.py


注:本文中的tensorflow.compat.v1.ConfigProto方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。