当前位置: 首页>>代码示例>>Python>>正文


Python v1.get_default_graph方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.get_default_graph方法的典型用法代码示例。如果您想了解tensorflow.compat.v1.get_default_graph方法的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.compat.v1的其他用法示例。


在下文中一共展示了v1.get_default_graph方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: underlying_variable

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def underlying_variable(t):
  """Return the tf.Variable object that backs tensor `t`.

  A name -> variable cache is stored on the default graph as the attribute
  `var_index`. Before the lookup, any global variables not yet in the cache
  are added to it.

  Args:
    t: a Tensor

  Returns:
    tf.Variable.
  """
  ref = underlying_variable_ref(t)
  assert ref is not None
  graph = tf.get_default_graph()
  # Lazily attach the cache to the graph the first time it is needed.
  if not hasattr(graph, "var_index"):
    graph.var_index = {}
  index = graph.var_index
  # Only entries past the already-cached count are added. This assumes
  # tf.global_variables() only grows by appending, so its first len(index)
  # entries are already cached.
  new_variables = tf.global_variables()[len(index):]
  for variable in new_variables:
    index[variable.name] = variable
  return index[ref.name]
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:20,代码来源:common_layers.py

示例2: testLossDecorated

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def testLossDecorated(self):
    """GammaFlops loss with a DummyDecorator(scale=0.5) is the scaled total."""
    # Build the batch-norm network and attach the regularizer under test.
    self.BuildWithBatchNorm(True)
    self.AddRegularizer()
    # Regularize conv3/conv4, wrapping per-op regularization in DummyDecorator
    # configured with scale 0.5.
    self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
        [self.conv3.op, self.conv4.op],
        gamma_threshold=0.45,
        regularizer_decorator=dummy_decorator.DummyDecorator,
        decorator_parameters={'scale': 0.5})

    graph_ops = tf.get_default_graph().get_operations()
    all_convs = [op for op in graph_ops if op.type == 'Conv2D']
    # Expected loss: the reference total term times the decorator scale (0.5),
    # whether GetLoss is given every Conv2D op or an empty list.
    expected_loss = 1410376.375 * 0.5
    self.assertAllClose(expected_loss, self.GetLoss(all_convs))
    self.assertAllClose(expected_loss, self.GetLoss([]))
开发者ID:google-research,项目名称:morph-net,代码行数:18,代码来源:flop_regularizer_test.py

示例3: test_group_lasso_conv3d

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_group_lasso_conv3d(self):
    """GroupLasso FLOP cost and regularization term for a VALID Conv3D layer."""
    shape = [3, 3, 3]
    video = tf.zeros([2, 3, 3, 3, 1])
    net = slim.conv3d(
        video,
        5,
        shape,
        padding='VALID',
        weights_initializer=tf.glorot_normal_initializer(),
        scope='vconv1')
    conv3d_op = tf.get_default_graph().get_operation_by_name('vconv1/Conv3D')
    # Input 1 of a Conv3D op is its filter tensor.
    conv3d_weights = conv3d_op.inputs[1]

    threshold = 0.09
    flop_reg = flop_regularizer.GroupLassoFlopsRegularizer([net.op],
                                                           threshold=threshold)
    # RMS of each output channel's filter: reduce over every filter axis
    # except the last (output-channel) one.
    norm = tf.sqrt(tf.reduce_mean(tf.square(conv3d_weights), [0, 1, 2, 3]))
    # Output channels whose filter norm exceeds the threshold count as alive.
    alive = tf.reduce_sum(tf.cast(norm > threshold, tf.float32))
    # A 3x3x3 kernel over a 3x3x3 single-channel input with VALID padding
    # yields one output position, so each alive channel costs 2 * kernel volume.
    flop_coeff = 2 * shape[0] * shape[1] * shape[2]
    with self.session():
      # Use the same `tf` namespace as the rest of this test instead of
      # reaching through tf.compat.v1 again.
      tf.global_variables_initializer().run()
      self.assertAllClose(flop_reg.get_cost(), flop_coeff * alive)
      self.assertAllClose(flop_reg.get_regularization_term(),
                          flop_coeff * tf.reduce_sum(norm))
开发者ID:google-research,项目名称:morph-net,代码行数:26,代码来源:flop_regularizer_test.py

示例4: testShareParams

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def testShareParams(self):
    """A decorated conv2d under reuse=True shares the explicit layer's weights."""
    first_outputs = 2
    alternate_num_outputs = 12
    # The parameterization maps 'first/Conv2D' to first_outputs channels. The
    # decorated call below therefore must match the explicit layer even though
    # it requests alternate_num_outputs.
    decorator = ops.ConfigurableOps(
        parameterization={'first/Conv2D': first_outputs})
    explicit = layers.conv2d(self.inputs, first_outputs, 3, scope='first')
    with arg_scope([layers.conv2d], reuse=True):
      decorated = decorator.conv2d(
          self.inputs,
          num_outputs=alternate_num_outputs,
          kernel_size=3,
          scope='first')
    with self.cached_session():
      tf.global_variables_initializer().run()
      # Identical outputs show both layers read the same (shared) variables.
      self.assertAllClose(explicit.eval(), decorated.eval())
    # Two distinct Conv2D ops must exist even though their weights are shared.
    conv_op_names = sorted(
        op.name for op in tf.get_default_graph().get_operations()
        if op.type == 'Conv2D')
    self.assertAllEqual(['first/Conv2D', 'first_1/Conv2D'], conv_op_names)
开发者ID:google-research,项目名称:morph-net,代码行数:26,代码来源:configurable_ops_test.py

示例5: test_fused_batchnorm

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_fused_batchnorm(self, use_depthwise):
    """The non-Keras extractor graph should contain a FusedBatchNorm* op."""
    image_placeholder = tf.placeholder(tf.float32, [1, 256, 256, 3])
    feature_extractor = self._create_feature_extractor(
        1,  # depth_multiplier
        1,  # pad_to_multiple
        use_keras=False,
        use_depthwise=use_depthwise)
    feature_extractor.extract_features(
        feature_extractor.preprocess(image_placeholder))
    # Substring match, so any op type containing 'FusedBatchNorm' qualifies.
    op_types = [op.type for op in tf.get_default_graph().get_operations()]
    self.assertTrue(any('FusedBatchNorm' in op_type for op_type in op_types))
开发者ID:tensorflow,项目名称:models,代码行数:20,代码来源:ssd_mobilenet_v2_fpn_feature_extractor_tf1_test.py

示例6: test_overwriting_activation_fn

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_overwriting_activation_fn(self):
    """A custom activation_fn (relu6) reaches both feature-extractor stages.

    Every listed architecture is built in its own graph, and the expected op
    names are derived from that architecture, so each one is verified.
    """
    for architecture in ['resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152']:
      # A fresh graph per architecture stops ops left over from an earlier
      # iteration from satisfying this iteration's assertions.
      with tf.Graph().as_default():
        feature_extractor = self._build_feature_extractor(
            first_stage_features_stride=16,
            architecture=architecture,
            activation_fn=tf.nn.relu6)
        preprocessed_inputs = tf.random_uniform([4, 224, 224, 3],
                                                maxval=255,
                                                dtype=tf.float32)
        rpn_feature_map, _ = feature_extractor.extract_proposal_features(
            preprocessed_inputs, scope='TestStage1Scope')
        _ = feature_extractor.extract_box_classifier_features(
            rpn_feature_map, scope='TestStaget2Scope')
        relu6_op_names = [
            op.name for op in tf.get_default_graph().get_operations()
            if op.type == 'Relu6'
        ]

        # A list is never None; require that at least one Relu6 op exists.
        self.assertTrue(relu6_op_names)
        self.assertIn(
            'TestStage1Scope/{0}/{0}/conv1/Relu6'.format(architecture),
            relu6_op_names)
        self.assertIn(
            'TestStaget2Scope/{}/block4/unit_1/bottleneck_v1/conv1/Relu6'
            .format(architecture), relu6_op_names)
开发者ID:tensorflow,项目名称:models,代码行数:27,代码来源:faster_rcnn_resnet_v1_feature_extractor_tf1_test.py

示例7: testQuantizationBuilderSetsUpCorrectTrainArguments

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
    """The training rewrite forwards graph and delay, and summarizes vars."""
    # Patch the quantization rewrite and the summary helper so the calls made
    # by the built rewrite function can be inspected.
    quant_patcher = mock.patch.object(contrib_quantize,
                                      'experimental_create_training_graph')
    summarize_patcher = mock.patch.object(slim, 'summarize_collection')
    with quant_patcher as mock_quant_fn, summarize_patcher as mock_summarize_col:
      rewriter_config = graph_rewriter_pb2.GraphRewriter()
      quantization = rewriter_config.quantization
      quantization.delay = 10
      quantization.weight_bits = 8
      quantization.activation_bits = 8
      rewrite_fn = graph_rewriter_builder.build(
          rewriter_config, is_training=True)
      rewrite_fn()
      _, call_kwargs = mock_quant_fn.call_args
      self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
      self.assertEqual(call_kwargs['quant_delay'], 10)
      mock_summarize_col.assert_called_with('quant_vars')
开发者ID:tensorflow,项目名称:models,代码行数:19,代码来源:graph_rewriter_builder_tf1_test.py

示例8: test_output_nodes_for_tflite

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_output_nodes_for_tflite(self):
    """The unrolled extractor graph exposes the raw TFLite input/output nodes."""
    image_height = 64
    image_width = 64
    depth_multiplier = 1.0
    pad_to_multiple = 1
    image_placeholder = tf.placeholder(tf.float32,
                                       [1, image_height, image_width, 3])
    feature_extractor = self._create_feature_extractor(depth_multiplier,
                                                       pad_to_multiple)
    preprocessed_image = feature_extractor.preprocess(image_placeholder)
    _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)

    tflite_nodes = [
        'raw_inputs/init_lstm_c',
        'raw_inputs/init_lstm_h',
        'raw_inputs/base_endpoint',
        'raw_outputs/lstm_c',
        'raw_outputs/lstm_h',
        'raw_outputs/base_endpoint_1',
        'raw_outputs/base_endpoint_2'
    ]
    ops_names = [op.name for op in tf.get_default_graph().get_operations()]
    for node in tflite_nodes:
      # Substring match: the expected name may be embedded in a longer
      # (e.g. scoped) op name. The message names the missing node on failure.
      self.assertTrue(
          any(node in s for s in ops_names),
          msg='No op name contains %r' % node)
开发者ID:tensorflow,项目名称:models,代码行数:26,代码来源:lstm_ssd_interleaved_mobilenet_v2_feature_extractor_test.py

示例9: test_fixed_concat_nodes

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def test_fixed_concat_nodes(self):
    """The quantized extractor graph contains the fixed concat input nodes."""
    image_height = 64
    image_width = 64
    depth_multiplier = 1.0
    pad_to_multiple = 1
    image_placeholder = tf.placeholder(tf.float32,
                                       [1, image_height, image_width, 3])
    feature_extractor = self._create_feature_extractor(
        depth_multiplier, pad_to_multiple, is_quantized=True)
    preprocessed_image = feature_extractor.preprocess(image_placeholder)
    _ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)

    concat_nodes = [
        'MobilenetV2_1/expanded_conv_16/project/Relu6',
        'MobilenetV2_2/expanded_conv_16/project/Relu6'
    ]
    ops_names = [op.name for op in tf.get_default_graph().get_operations()]
    for node in concat_nodes:
      # Substring match: the expected name may be embedded in a longer
      # (e.g. scoped) op name. The message names the missing node on failure.
      self.assertTrue(
          any(node in s for s in ops_names),
          msg='No op name contains %r' % node)
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:lstm_ssd_interleaved_mobilenet_v2_feature_extractor_test.py

示例10: run_benchmark

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def run_benchmark(bench_cnn, num_iters):
  """Runs the all-reduce benchmark.

  Builds the all-reduce graph, optionally writes the GraphDef to
  bench_cnn.graph_file, then executes it via run_graph.

  Args:
    bench_cnn: The BenchmarkCNN where params, the variable manager, and other
      attributes are obtained.
    num_iters: Number of iterations to do all-reduce for.

  Raises:
    ValueError: Invalid params of bench_cnn.
  """
  if bench_cnn.params.variable_update != 'replicated':
    # Adjacent literals are concatenated, so the trailing space is required
    # to avoid producing "to usethe".
    raise ValueError('--variable_update=replicated must be specified to use '
                     'the all-reduce benchmark')
  if bench_cnn.params.variable_consistency == 'relaxed':
    raise ValueError('--variable_consistency=relaxed is not supported')

  benchmark_op = build_graph(bench_cnn.raw_devices,
                             get_var_shapes(bench_cnn.model),
                             bench_cnn.variable_mgr, num_iters)
  init_ops = [
      tf.global_variables_initializer(),
      bench_cnn.variable_mgr.get_post_init_ops()
  ]
  # run_graph takes a loss op; presumably this benchmark has no real loss, so
  # a no-op is passed in its place.
  loss_op = tf.no_op()

  if bench_cnn.graph_file:
    path, filename = os.path.split(bench_cnn.graph_file)
    # A filename ending in 'txt' selects the text format for the GraphDef.
    as_text = filename.endswith('txt')
    log_fn('Writing GraphDef as %s to %s' % (
        'text' if as_text else 'binary', bench_cnn.graph_file))
    tf.train.write_graph(tf.get_default_graph().as_graph_def(add_shapes=True),
                         path, filename, as_text)

  run_graph(benchmark_op, bench_cnn, init_ops, loss_op)


# TODO(reedwm): Reduce redundancy with tf_cnn_benchmarks 
开发者ID:tensorflow,项目名称:benchmarks,代码行数:40,代码来源:all_reduce_benchmark.py

示例11: find_ops

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def find_ops(optype, graph=None):
  """Find ops of a given type in a graph.

  Args:
    optype: operation type (e.g. Conv2D)
    graph: graph to search (anything with `get_operations()`, e.g. a
      tf.Graph). Defaults to the current default graph, which preserves the
      original single-argument behavior.
  Returns:
     List of operations whose `type` equals `optype`, in graph order.
  """
  if graph is None:
    graph = tf.get_default_graph()
  return [op for op in graph.get_operations() if op.type == optype]
开发者ID:tensorflow,项目名称:benchmarks,代码行数:12,代码来源:mobilenet_test.py

示例12: _run_eval

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def _run_eval(self):
    """Evaluate a model every self.params.eval_interval_secs.

    Each pass opens a session, restores the latest checkpoint from
    self.params.train_dir and runs one evaluation. When eval_interval_secs is
    <= 0, a single pass is made; otherwise the loop keeps going until an
    exception propagates. The summary writer is closed on exit either way.

    Returns:
      Dictionary containing eval statistics. Currently returns an empty
      dictionary.

    Raises:
      ValueError: If self.params.train_dir is unspecified.
    """
    if self.params.train_dir is None:
      raise ValueError('Trained model directory not specified')
    graph_info = self._build_eval_graph()
    saver = tf.train.Saver(self.variable_mgr.savable_variables())
    summary_writer = tf.summary.FileWriter(self.params.eval_dir,
                                           tf.get_default_graph())
    target = ''
    try:
      # TODO(huangyp): Check if checkpoints haven't updated for hours and abort.
      while True:
        with tf.Session(
            target=target, config=create_config_proto(self.params)) as sess:
          image_producer = None
          try:
            global_step = load_checkpoint(saver, sess, self.params.train_dir)
            image_producer = self._initialize_eval_graph(
                graph_info.enqueue_ops, graph_info.input_producer_op,
                graph_info.local_var_init_op_group, sess)
          except CheckpointNotFoundException:
            log_fn('Checkpoint not found in %s' % self.params.train_dir)
          else:  # Only executes if an exception was not thrown
            self._eval_once(sess, summary_writer, graph_info.fetches,
                            graph_info.summary_op, image_producer, global_step)
          # Tell the image producer we are done with it, if one was created.
          if image_producer is not None:
            image_producer.done()
          if self.params.eval_interval_secs <= 0:
            break
          time.sleep(self.params.eval_interval_secs)
    finally:
      # Close the writer so buffered summaries from the last pass are flushed
      # to eval_dir and the event file is released, even on exceptions.
      summary_writer.close()
    return {}
开发者ID:tensorflow,项目名称:benchmarks,代码行数:40,代码来源:benchmark_cnn.py

示例13: remove_summaries

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def remove_summaries():
  """Empty the SUMMARIES collection of the current default graph."""
  graph = tf.get_default_graph()
  summaries_key = tf.GraphKeys.SUMMARIES
  current = graph.get_collection(summaries_key)
  log_debug("Remove summaries %s" % str(current))
  # get_collection_ref returns the collection's own list, so deleting its
  # contents in place empties the collection itself.
  live_summaries = graph.get_collection_ref(summaries_key)
  del live_summaries[:]
  assert not graph.get_collection(summaries_key)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:9,代码来源:t2t_model.py

示例14: framework

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def framework(msg='err'):
  """Return the contrib `framework` module, or a stand-in under TF2.

  When `is_tf2` is falsy, tensorflow.contrib.framework is imported and
  returned. Otherwise a DummyModule is returned that exposes arg_scope (None),
  get_name_scope, name_scope, deprecated, nest and argsort.

  Args:
    msg: Ignored.
  """
  del msg  # Ignored.
  if not is_tf2:
    # tensorflow.contrib does not exist in TF2, so import it only on this path.
    from tensorflow.contrib import framework as contrib_framework  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    return contrib_framework
  return DummyModule(
      arg_scope=None,
      get_name_scope=lambda: tf.get_default_graph().get_name_scope(),
      name_scope=tf.name_scope,
      deprecated=deprecated,
      nest=tf.nest,
      argsort=tf.argsort)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:16,代码来源:contrib.py

示例15: _get_beta_accumulators

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_default_graph [as 别名]
def _get_beta_accumulators(self):
    """Return the (beta1_power, beta2_power) non-slot accumulator variables."""
    # Look the variables up from the init_scope context. In eager mode no
    # graph is passed; otherwise the default graph in that context is used.
    with tf.init_scope():
      graph = None if tf.executing_eagerly() else tf.get_default_graph()
      beta1_power = self._get_non_slot_variable("beta1_power", graph=graph)
      beta2_power = self._get_non_slot_variable("beta2_power", graph=graph)
      return beta1_power, beta2_power
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:10,代码来源:multistep_with_adamoptimizer.py


注:本文中的tensorflow.compat.v1.get_default_graph方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。