当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.get_default_graph方法代码示例

本文整理汇总了Python中tensorflow.get_default_graph方法的典型用法代码示例。如果您正苦于以下问题:Python tensorflow.get_default_graph方法的具体用法?Python tensorflow.get_default_graph怎么用?Python tensorflow.get_default_graph使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow的用法示例。


在下文中一共展示了tensorflow.get_default_graph方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: build_from_pb

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def build_from_pb(self):
        """Load a frozen graph plus its metadata and bind the I/O tensors.

        Reads the serialized GraphDef from ``self.FLAGS.pbLoad``, imports it
        into the default graph, loads the JSON metadata from
        ``self.FLAGS.metaLoad`` into ``self.meta``, stores the result of
        ``create_framework`` in ``self.framework`` and finally calls
        ``self.setup_meta_ops()``.
        """
        # GFile is the supported file wrapper; FastGFile is a deprecated alias.
        with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # An empty name prefix keeps the node names as stored in the file, so
        # the 'input:0' / 'output:0' lookups below need no prefix.
        tf.import_graph_def(graph_def, name="")

        with open(self.FLAGS.metaLoad, 'r') as fp:
            self.meta = json.load(fp)
        self.framework = create_framework(self.meta, self.FLAGS)

        # Placeholders
        graph = tf.get_default_graph()
        self.inp = graph.get_tensor_by_name('input:0')
        self.feed = dict()  # other placeholders
        self.out = graph.get_tensor_by_name('output:0')

        self.setup_meta_ops()
开发者ID:AmeyaWagh,项目名称:Traffic_sign_detection_YOLO,代码行数:21,代码来源:build.py

示例2: init_uninited_vars

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def init_uninited_vars(vars=None):
    """Run the initializer of every listed variable that is not initialized yet.

    When ``vars`` is None, ``tf.global_variables()`` is used. A variable whose
    ``<name>/IsVariableInitialized:0`` tensor already exists in the default
    graph is skipped; for the others an ``is_variable_initialized`` op is
    created and evaluated, and only variables reported as uninitialized get
    their initializer run.
    """
    if vars is None:
        vars = tf.global_variables()

    pending_vars = []
    pending_checks = []
    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        for variable in vars:
            assert is_tf_expression(variable)
            probe_name = variable.name.replace(':0', '/IsVariableInitialized:0')
            try:
                tf.get_default_graph().get_tensor_by_name(probe_name)
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                pending_vars.append(variable)
                with absolute_name_scope(variable.name.split(':')[0]):
                    pending_checks.append(tf.is_variable_initialized(variable))

    check_results = run(pending_checks)
    uninitialized = [
        variable
        for variable, is_inited in zip(pending_vars, check_results)
        if not is_inited
    ]
    run([variable.initializer for variable in uninitialized])

#----------------------------------------------------------------------------
# Set the values of given tf.Variables.
# Equivalent to the following, but more efficient and does not bloat the tf graph:
#   tfutil.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
开发者ID:zalandoresearch,项目名称:disentangling_conditional_gans,代码行数:22,代码来源:tfutil.py

示例3: _py_func_with_gradient

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def _py_func_with_gradient(func, inp, Tout, stateful=True, name=None,
                           grad_func=None):
    """
    Create a tf.py_func op whose gradient is overridden by grad_func
    :param func: Custom Python function wrapped by the op
    :param inp: Inputs passed to the function
    :param Tout: Output type(s) of the custom function
    :param stateful: Forwarded unchanged to tf.py_func
    :param name: Name of the PyFunc op
    :param grad_func: Custom gradient function registered for the op
    :return: The result of tf.py_func
    """
    # A random suffix keeps the registered gradient name from colliding
    # with built-in names or with earlier calls.
    random_suffix = '%0x' % getrandbits(30 * 4)
    override_name = 'PyFuncGrad-' + random_suffix

    # Register the custom gradient under the freshly generated name.
    tf.RegisterGradient(override_name)(grad_func)

    # Both PyFunc op types are redirected to the registered gradient.
    overrides = {"PyFunc": override_name, "PyFuncStateless": override_name}
    graph = tf.get_default_graph()
    with graph.gradient_override_map(overrides):
        py_func_op = tf.py_func(func, inp, Tout, stateful=stateful, name=name)
    return py_func_op
开发者ID:StephanZheng,项目名称:neural-fingerprinting,代码行数:27,代码来源:utils_pytorch.py

示例4: network_surgery

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def network_surgery():
    """Re-import the saved meta graph onto new placeholders and export it.

    Resets the default graph, creates four placeholders, imports
    ``model_files/model.tf.meta`` with those placeholders mapped onto the
    original input tensors, and writes the resulting graph as a text meta
    graph to ``model_files/model.tf-modified.meta``.
    """
    tf.reset_default_graph()

    seq_input = tf.placeholder(tf.float32, shape=(None, 131072, 4),
                               name='inputs')
    target_input = tf.placeholder(tf.float32, shape=(None, 1024, 4229),
                                  name='targets')
    target_mask = tf.placeholder(tf.bool, shape=(None, 1024),
                                 name="targets_na")
    adhoc_preds = tf.placeholder(tf.float32, shape=(None, 960, 4229),
                                 name="Placeholder_15")

    # Tensor names in the stored graph -> replacement placeholders.
    replacements = {
        'Placeholder_15:0': adhoc_preds,
        'Placeholder:0': target_mask,
        'inputs:0': seq_input,
        'targets:0': target_input,
    }
    tf.train.import_meta_graph("model_files/model.tf.meta",
                               input_map=replacements)

    graph_ops = tf.get_default_graph().get_operations()

    tf.train.export_meta_graph(filename='model_files/model.tf-modified.meta',
                               as_text=True)

    # NOTE(review): the slice below is evaluated and discarded; it looks like
    # a leftover from interactive inspection of the first operations.
    graph_ops[:15]
开发者ID:kipoi,项目名称:models,代码行数:25,代码来源:test_model.py

示例5: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def main(_):
  """Build the eval graph and evaluate once or repeatedly.

  Evaluation runs at least once. Unless ``FLAGS.run_once`` is set, it is
  repeated forever with ``FLAGS.eval_interval_secs`` seconds of sleep
  between runs.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  tf.logging.info('Building eval graph...')
  eval_ops, moving_averaged_variables = graphs.get_model().eval_graph(
      FLAGS.eval_data)

  saver = tf.train.Saver(moving_averaged_variables)
  default_graph = tf.get_default_graph()
  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, graph=default_graph)

  run_eval(eval_ops, summary_writer, saver)
  while not FLAGS.run_once:
    time.sleep(FLAGS.eval_interval_secs)
    run_eval(eval_ops, summary_writer, saver)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:18,代码来源:evaluate.py

示例6: count_weights

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def count_weights(scope=None, exclude=None, graph=None):
  """Count learnable parameters.

  Args:
    scope: Restrict the count to variables whose name starts with this
      scope; a trailing '/' is appended when it is missing.
    exclude: Regex; variables whose name matches it from the start are
      left out of the count.
    graph: Operate on a graph other than the current default graph.

  Returns:
    Number of learnable parameters as integer.
  """
  prefix = None
  if scope:
    prefix = scope if scope.endswith('/') else scope + '/'

  target_graph = graph or tf.get_default_graph()
  variables = target_graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

  if prefix:
    variables = [v for v in variables if v.name.startswith(prefix)]
  if exclude:
    pattern = re.compile(exclude)
    variables = [v for v in variables if not pattern.match(v.name)]

  # The element count of one variable is the product of its static shape.
  shapes = [v.get_shape().as_list() for v in variables]
  total = sum(np.prod(shape) for shape in shapes)
  return int(total)
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:24,代码来源:count_weights.py

示例7: build_from_pb

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def build_from_pb(self):
        """Restore the network from a frozen ``.pb`` file and its metadata.

        The GraphDef stored at ``self.FLAGS.pbLoad`` is imported into the
        default graph, the JSON file at ``self.FLAGS.metaLoad`` is loaded into
        ``self.meta``, ``self.framework`` is set from ``create_framework``,
        the input/output tensors are looked up by name and
        ``self.setup_meta_ops()`` is called last.
        """
        # Use GFile rather than the deprecated FastGFile alias.
        with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # name="" imports the nodes without a prefix, which is what the
        # un-prefixed tensor names used below rely on.
        tf.import_graph_def(graph_def, name="")

        with open(self.FLAGS.metaLoad, 'r') as fp:
            self.meta = json.load(fp)
        self.framework = create_framework(self.meta, self.FLAGS)

        # Placeholders
        default_graph = tf.get_default_graph()
        self.inp = default_graph.get_tensor_by_name('input:0')
        self.feed = dict()  # other placeholders
        self.out = default_graph.get_tensor_by_name('output:0')

        self.setup_meta_ops()
开发者ID:MahmudulAlam,项目名称:Automatic-Identification-and-Counting-of-Blood-Cells,代码行数:21,代码来源:build.py

示例8: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def __init__(self, checkpoint_filename, input_name="images",
                 output_name="features"):
        """Load a serialized GraphDef and expose its input/output tensors.

        The graph is imported under the ``net`` prefix. The output tensor
        must have rank 2 and the input tensor rank 4; ``feature_dim`` is the
        last output dimension and ``image_shape`` is the input shape without
        its first dimension.
        """
        self.session = tf.Session()

        with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
            serialized_graph = file_handle.read()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(graph_def, name="net")

        graph = tf.get_default_graph()
        self.input_var = graph.get_tensor_by_name("net/%s:0" % input_name)
        self.output_var = graph.get_tensor_by_name("net/%s:0" % output_name)

        output_shape = self.output_var.get_shape()
        input_shape = self.input_var.get_shape()
        assert len(output_shape) == 2
        assert len(input_shape) == 4
        self.feature_dim = output_shape.as_list()[-1]
        self.image_shape = input_shape.as_list()[1:]
开发者ID:nwojke,项目名称:deep_sort,代码行数:18,代码来源:generate_detections.py

示例9: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def main():
    """Restore a checkpoint and write the network out as a frozen GraphDef.

    Builds the network on a uint8 ``images`` placeholder inside a fresh
    graph, restores variables from ``args.checkpoint_in``, converts them to
    constants and writes the serialized result to ``args.graphdef_out``.
    """
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as session:
        raw_images = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        float_images = tf.cast(raw_images, tf.float32)
        preprocessed = tf.map_fn(_preprocess, float_images, back_prop=False)

        build_network = _network_factory()
        features, _ = build_network(preprocessed, reuse=None)
        # Give the output a fixed name so it can be looked up after freezing.
        features = tf.identity(features, name="features")

        restorer = tf.train.Saver(slim.get_variables_to_restore())
        restorer.restore(session, args.checkpoint_in)

        output_node = features.name.split(":")[0]
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            session, tf.get_default_graph().as_graph_def(), [output_node])

        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(frozen_graph_def.SerializeToString())
开发者ID:nwojke,项目名称:deep_sort,代码行数:24,代码来源:freeze_model.py

示例10: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def __init__(self, train_df, word_count, batch_size, epochs):
        """Configure the backend session and derive the model size limits.

        Sets a fixed TensorFlow random seed, installs a session with explicit
        thread-pool sizes on the backend, and stores the training settings.
        Each categorical limit is one more than the maximum value found in
        the corresponding ``train_df`` column.
        """
        tf.set_random_seed(4)
        thread_config = tf.ConfigProto(intra_op_parallelism_threads=2,
                                       inter_op_parallelism_threads=8)
        session = tf.Session(graph=tf.get_default_graph(),
                             config=thread_config)
        backend.set_session(session)

        self.batch_size, self.epochs = batch_size, epochs

        self.max_name_seq, self.max_item_desc_seq = 10, 75
        self.max_text = word_count + 1

        def one_past_max(column):
            # Largest value in the column, plus one.
            return np.max(column.max()) + 1

        self.max_brand = one_past_max(train_df.brand_name)
        self.max_condition = one_past_max(train_df.item_condition_id)
        self.max_subcat0 = one_past_max(train_df.subcat_0)
        self.max_subcat1 = one_past_max(train_df.subcat_1)
        self.max_subcat2 = one_past_max(train_df.subcat_2)
开发者ID:aerdem4,项目名称:mercari-price-suggestion,代码行数:18,代码来源:nn_model.py

示例11: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def __init__(self, num_features, discriminator=discriminator,
                 generator=generator_gatedcnn, mode='train',
                 log_dir='./log'):
        """Build the model, start a session and, in train mode, set up logging.

        After the graph and optimizers are built, a saver and a session are
        created and all global variables are initialized. Only when ``mode``
        is ``'train'`` are the step counter, a timestamped log directory, the
        summary writer and the summaries created.
        """
        self.num_features = num_features
        # [batch_size, num_features, num_frames]
        self.input_shape = [None, num_features, None]

        self.discriminator = discriminator
        self.generator = generator
        self.mode = mode

        self.build_model()
        self.optimizer_initializer()

        self.saver = tf.train.Saver()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        if self.mode != 'train':
            return

        self.train_step = 0
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        self.log_dir = os.path.join(log_dir, timestamp)
        self.writer = tf.summary.FileWriter(self.log_dir,
                                            tf.get_default_graph())
        self.generator_summaries, self.discriminator_summaries = self.summary()
开发者ID:leimao,项目名称:Voice_Converter_CycleGAN,代码行数:24,代码来源:model.py

示例12: test_with_dynamic_shape

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def test_with_dynamic_shape(self):
    """Checks static_or_dynamic_map_fn on an input with unknown batch size.

    Verifies that at least one op whose name starts with 'map' is present
    in the default graph, and that the per-row sums are correct for two
    different batch sizes.
    """
    def row_sum(row):
      return tf.reduce_sum(row)

    rows = tf.placeholder(tf.float32, shape=(None, 2))
    mapped = shape_utils.static_or_dynamic_map_fn(row_sum, rows)

    graph_ops = tf.get_default_graph().get_operations()
    self.assertTrue(any(op.name.startswith('map') for op in graph_ops))

    with self.test_session() as sess:
      three_rows = sess.run(
          mapped, feed_dict={rows: [[1, 2], [3, 1], [0, 4]]})
      two_rows = sess.run(
          mapped, feed_dict={rows: [[-1, 1], [0, 9]]})
      self.assertAllEqual(three_rows, [3, 4, 4])
      self.assertAllEqual(two_rows, [0, 9])
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:20,代码来源:shape_utils_test.py

示例13: test_has_fused_batchnorm

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def test_has_fused_batchnorm(self, use_keras):
    """Checks that building the extractor adds a FusedBatchNorm op.

    The extractor is run on a 1x40x40x3 float placeholder, either by
    calling it directly (``use_keras``) or through ``extract_features``,
    and the default graph is then searched for an op of type
    'FusedBatchNorm'.
    """
    image_height = image_width = 40
    depth_multiplier = pad_to_multiple = 1

    image_placeholder = tf.placeholder(
        tf.float32, [1, image_height, image_width, 3])
    feature_extractor = self._create_feature_extractor(
        depth_multiplier, pad_to_multiple, use_keras=use_keras)
    preprocessed_image = feature_extractor.preprocess(image_placeholder)

    # The output is irrelevant here; only the ops added to the graph matter.
    if use_keras:
      extract = feature_extractor
    else:
      extract = feature_extractor.extract_features
    _ = extract(preprocessed_image)

    op_types = [op.type for op in tf.get_default_graph().get_operations()]
    self.assertTrue('FusedBatchNorm' in op_types)
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:19,代码来源:ssd_mobilenet_v2_feature_extractor_test.py

示例14: build

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def build(graph_rewriter_config, is_training):
  """Returns a function that modifies default graph based on options.

  Args:
    graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
    is_training: whether in training or eval mode.
  """
  def graph_rewrite_fn():
    """Function to quantize weights and activation of the default graph."""
    quantization = graph_rewriter_config.quantization
    if not (quantization.weight_bits == 8 and
            quantization.activation_bits == 8):
      raise ValueError('Only 8bit quantization is supported')

    # Quantize the graph by inserting quantize ops for weights and activations
    default_graph = tf.get_default_graph()
    if is_training:
      tf.contrib.quantize.create_training_graph(
          input_graph=default_graph,
          quant_delay=quantization.delay)
    else:
      tf.contrib.quantize.create_eval_graph(input_graph=default_graph)

    tf.contrib.layers.summarize_collection('quant_vars')

  return graph_rewrite_fn
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:25,代码来源:graph_rewriter_builder.py

示例15: testQuantizationBuilderSetsUpCorrectTrainArguments

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_default_graph [as 别名]
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
    """Checks the arguments the training rewrite passes to the quantizer.

    With ``create_training_graph`` and ``summarize_collection`` patched,
    an 8-bit config with delay 10 must lead to a call with the default
    graph and ``quant_delay=10``, followed by a summary of 'quant_vars'.
    """
    with mock.patch.object(
        tf.contrib.quantize, 'create_training_graph') as mock_quant_fn, \
         mock.patch.object(
             tf.contrib.layers, 'summarize_collection') as mock_summarize_col:
      graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
      quantization = graph_rewriter_proto.quantization
      quantization.delay = 10
      quantization.weight_bits = 8
      quantization.activation_bits = 8

      rewrite = graph_rewriter_builder.build(
          graph_rewriter_proto, is_training=True)
      rewrite()

      _, call_kwargs = mock_quant_fn.call_args
      self.assertEqual(call_kwargs['input_graph'], tf.get_default_graph())
      self.assertEqual(call_kwargs['quant_delay'], 10)
      mock_summarize_col.assert_called_with('quant_vars')
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:18,代码来源:graph_rewriter_builder_test.py


注:本文中的tensorflow.get_default_graph方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。