当前位置: 首页>>代码示例>>Python>>正文


Python v1.reset_default_graph方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.reset_default_graph方法的典型用法代码示例。如果您正苦于以下问题：Python v1.reset_default_graph方法的具体用法？Python v1.reset_default_graph怎么用？Python v1.reset_default_graph使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.reset_default_graph方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testFlopRegularizerDontConvertToVariable

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testFlopRegularizerDontConvertToVariable(self):
    """Evaluates a GroupLassoFlopsRegularizer built on a resource-variable matmul.

    A fresh graph (graph-level seed 1234) is built containing one matmul
    between a constant input and a resource variable. The regularizer is
    constructed on that matmul op, and its regularization term is evaluated
    after variable initialization.
    """
    tf.reset_default_graph()
    tf.set_random_seed(1234)

    inputs = tf.constant(1.0, shape=[2, 6], name='x', dtype=tf.float32)
    weights = tf.Variable(
        tf.truncated_normal([6, 4], stddev=1.0), use_resource=True)
    product = tf.matmul(inputs, weights)

    # FLOP regularizer over the matmul op, threshold 0.9; the trailing 0 is
    # forwarded positionally exactly as in the original test.
    regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
        [product.op], 0.9, 0)

    with self.cached_session():
        tf.global_variables_initializer().run()
        regularizer.get_regularization_term().eval()
开发者ID:google-research,项目名称:morph-net,代码行数:18,代码来源:flop_regularizer_test.py

示例2: testCreateDropoutWithPlaceholder

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testCreateDropoutWithPlaceholder(self):
    """Dropout controlled by a boolean placeholder must be built as a cond.

    Also checks that the dropout output shape is compatible with its input.
    """
    tf.reset_default_graph()
    height = width = 3
    with self.cached_session():
        training_flag = array_ops.placeholder(dtype=dtypes.bool, shape=[])
        images = random_ops.random_uniform((5, height, width, 3), seed=1)
        # This verifies that a cond was inserted for the placeholder branch.
        output = _layers.dropout(images, is_training=training_flag)
        # Under control_flow_v2 the op is called "If" and sits behind an
        # identity op; with legacy cond we match on the Merge op name.
        # A more robust check might be needed eventually.
        producer_type = output.op.inputs[0].op.type
        is_cond_op = (producer_type == 'If' or
                      output.op.name == 'Dropout/cond/Merge')
        self.assertTrue(is_cond_op,
                        'Expected cond_op got ' + repr(output))
        output.get_shape().assert_is_compatible_with(images.get_shape())
开发者ID:google-research,项目名称:tf-slim,代码行数:18,代码来源:layers_test.py

示例3: init_data_normalizer

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def init_data_normalizer(config):
  """Initializes data normalizer.

  Returns immediately if the normalizer already exists. Otherwise, task 0
  builds a fresh graph, draws a batch of 10 real images and saves the
  normalizer from them. Every other task polls every 5 seconds until the
  normalizer exists.

  Args:
    config: configuration mapping; this function reads the keys
      'data_normalizer', 'task' and 'data_type', and passes the whole mapping
      to the registered normalizer and data-helper constructors.
  """
  normalizer = data_normalizer.registry[config['data_normalizer']](config)
  if normalizer.exists():
    return

  if config['task'] != 0:
    # Tasks other than 0 wait for task 0 to write the normalizer.
    while not normalizer.exists():
      time.sleep(5)
    return

  tf.reset_default_graph()
  data_helper = data_helpers.registry[config['data_type']](config)
  real_images, _ = data_helper.provide_data(batch_size=10)

  # Save normalizer.
  # If the normalizer has already been saved, save() is a no-op. To regenerate
  # it, delete the normalizer file in train_root_dir/assets.
  normalizer.save(real_images)
开发者ID:magenta,项目名称:magenta,代码行数:20,代码来源:gansynth_train.py

示例4: run

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def run(config):
  """Entry point to run training.

  Initializes the data normalizer, then trains one stage at a time: every
  stage when config['train_progressive'] is truthy, otherwise only the last
  stage. Each stage gets a freshly reset default graph and is built under a
  replica device setter.
  """
  init_data_normalizer(config)

  all_stage_ids = train_util.get_stage_ids(**config)
  stages_to_train = (all_stage_ids if config['train_progressive']
                     else list(all_stage_ids)[-1:])

  for stage_id in stages_to_train:
    batch_size = train_util.get_batch_size(stage_id, **config)
    tf.reset_default_graph()
    device_fn = tf.train.replica_device_setter(config['ps_tasks'])
    with tf.device(device_fn):
      model = lib_model.Model(stage_id, batch_size, config)
      model.add_summaries()
      print('Variables:')
      for variable in tf.global_variables():
        print('\t', variable.name, variable.get_shape().as_list())
      logging.info('Calling train.train')
      train_util.train(model, **config)
开发者ID:magenta,项目名称:magenta,代码行数:22,代码来源:gansynth_train.py

示例5: load_entities

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def load_entities(self, base_dir):
    """Load entity ids and masks, then derive IDF and TF-IDF statistics.

    Reads the "entity_ids" and "entity_mask" checkpoints under `base_dir`
    into a fresh default graph, materializes them with a session, and sets
    `self.idfs` and `self.ent_tfidfs` from the resulting count matrix.

    Args:
      base_dir: directory containing the "entity_ids" and "entity_mask"
        checkpoints.
    """
    tf.reset_default_graph()
    entity_ids = search_utils.load_database(
        "entity_ids", None, os.path.join(base_dir, "entity_ids"),
        dtype=tf.int32)
    entity_mask = search_utils.load_database(
        "entity_mask", None, os.path.join(base_dir, "entity_mask"))
    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      session.run(tf.local_variables_initializer())
      tf.logging.info("Loading entity ids and masks...")
      ids_value, mask_value = session.run([entity_ids, entity_mask])
    tf.logging.info("Building entity count matrix...")
    counts = search_utils.build_count_matrix(ids_value, mask_value)
    tf.logging.info("Computing IDFs...")
    self.idfs = search_utils.counts_to_idfs(counts, cutoff=1e-5)
    tf.logging.info("Computing entity Tf-IDFs...")
    tfidfs = search_utils.counts_to_tfidf(counts, self.idfs)
    # NOTE(review): `normalize` is presumably sklearn.preprocessing.normalize
    # (L2 norm along axis 0) — confirm against the file's imports.
    self.ent_tfidfs = normalize(tfidfs, norm="l2", axis=0)
开发者ID:google-research,项目名称:language,代码行数:24,代码来源:demo.py

示例6: parse

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def parse(self, onnx_file, output_nodes=None, model_name=None):
    """Parse an ONNX model file into a legalized uTensorGraph.

    Eager execution is disabled, and the TF default graph is reset before
    parsing. It is reset again afterwards, and now also when parsing fails,
    so nothing built during parsing is left behind in the default graph.

    Args:
      onnx_file: path to the ONNX model file.
      output_nodes: accepted for interface compatibility; not used here (the
        uTensorGraph is created with an empty output_nodes list).
      model_name: optional graph name; when falsy, the file's base name
        without its extension is used.

    Returns:
      The uTensorGraph produced by Legalizer.legalize.
    """
    tf.disable_eager_execution()
    graph_name = (model_name or
                  os.path.splitext(os.path.basename(onnx_file))[0])
    tf.reset_default_graph()
    try:
      model = onnx.load(onnx_file)
      ugraph = uTensorGraph(
        name=graph_name,
        output_nodes=[],
        lib_name='onnx',
        ops_info={},
      )
      self._build_graph(model.graph, ugraph)
      ugraph = Legalizer.legalize(ugraph)
    finally:
      # Previously this cleanup was skipped if loading/building raised,
      # leaving whatever _build_graph had added in the default graph.
      tf.reset_default_graph()
    return ugraph
开发者ID:uTensor,项目名称:utensor_cgen,代码行数:23,代码来源:onnx.py

示例7: test_generator_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def test_generator_graph(self):
    """Checks DCGAN generator output shape, end points and layer depths.

    For each (i, batch_size) pair, a fresh graph is built and the generator
    is asked for images of side 2**i. The test checks:
    - the image shape;
    - the end-point names ('deconv1'..'deconv{i-1}' plus 'logits');
    - that 'deconv{j}' has depth 32 * 2**(i - j - 1).
    """
    # Check graph construction for a number of image size/depths and batch
    # sizes. zip() stops at the shorter range, so i takes the values 3..6.
    for i, batch_size in zip(range(3, 7), range(3, 8)):
      tf.reset_default_graph()
      # The graph-level seed belongs to the default graph, so it must be set
      # after every reset. A single call before the loop was discarded by
      # the first reset_default_graph() and never took effect.
      tf.set_random_seed(1234)
      final_size = 2 ** i
      noise = tf.random.normal([batch_size, 64])
      image, end_points = dcgan.generator(
          noise,
          depth=32,
          final_size=final_size)

      self.assertAllEqual([batch_size, final_size, final_size, 3],
                          image.shape.as_list())

      expected_names = ['deconv%i' % j for j in range(1, i)] + ['logits']
      self.assertSetEqual(set(expected_names), set(end_points.keys()))

      # Check layer depths.
      for j in range(1, i):
        layer = end_points['deconv%i' % j]
        self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1])
开发者ID:tensorflow,项目名称:models,代码行数:25,代码来源:dcgan_test.py

示例8: test_discriminator_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def test_discriminator_graph(self):
    """Checks DCGAN discriminator output shape, end points and layer depths.

    Image widths 2**1..2**5 are paired with batch sizes 3..7, each in a fresh
    graph. For width 2**i, the discriminator must return a
    [batch_size, 1] output and the end points 'conv1'..'conv{i}' plus
    'logits'. 'conv{j}' must have depth 32 * 2**(j - 1).
    """
    for i in xrange(1, 6):
      # Same pairing as zip(xrange(1, 6), xrange(3, 8)).
      batch_size = i + 2
      tf.reset_default_graph()
      img_w = 2 ** i
      image = tf.random.uniform([batch_size, img_w, img_w, 3], -1, 1)
      output, end_points = dcgan.discriminator(image, depth=32)

      self.assertAllEqual([batch_size, 1], output.get_shape().as_list())

      conv_names = ['conv%i' % j for j in xrange(1, i + 1)]
      self.assertSetEqual(set(conv_names + ['logits']),
                          set(end_points.keys()))

      # Depth doubles with every conv layer, starting from 32.
      for j, name in enumerate(conv_names, start=1):
        depth = end_points[name].get_shape().as_list()[-1]
        self.assertEqual(32 * 2**(j - 1), depth)
开发者ID:tensorflow,项目名称:models,代码行数:22,代码来源:dcgan_test.py

示例9: testGlobalPoolUnknownImageShape

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testGlobalPoolUnknownImageShape(self):
    """MobileNetV1 with global pooling accepts inputs of unknown height/width.

    The network is built on a placeholder whose spatial dims are None, and
    the static logits shape is checked. A 250x300 image is then fed, and the
    'Conv2d_13_pointwise' end point must have shape [1, 8, 10, 1024].
    """
    tf.reset_default_graph()
    batch_size = 1
    height, width = 250, 300
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes,
                                                     global_pool=True)
      self.assertTrue(logits.op.name.startswith('MobilenetV1/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Conv2d_13_pointwise']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 10, 1024])
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:mobilenet_v1_test.py

示例10: testUnknownImageShape

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testUnknownImageShape(self):
    """InceptionV2 builds on a placeholder with unknown height/width.

    The static logits shape is checked. A 224x224 batch of 2 is then fed, and
    the 'Mixed_5c' end point must have shape [2, 7, 7, 1024].
    """
    tf.reset_default_graph()
    batch_size = 2
    height, width = 224, 224
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v2(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_5c']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
开发者ID:tensorflow,项目名称:models,代码行数:20,代码来源:inception_v2_test.py

示例11: testGlobalPoolUnknownImageShape

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testGlobalPoolUnknownImageShape(self):
    """InceptionV2 with global pooling accepts inputs of unknown height/width.

    The static logits shape is checked. A 250x300 image is then fed, and the
    'Mixed_5c' end point must have shape [1, 8, 10, 1024].
    """
    tf.reset_default_graph()
    batch_size = 1
    height, width = 250, 300
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v2(inputs, num_classes,
                                                  global_pool=True)
      self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_5c']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 10, 1024])
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:inception_v2_test.py

示例12: testUnknownImageShape

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testUnknownImageShape(self):
    """InceptionV3 builds on a placeholder with unknown height/width.

    The static logits shape is checked. A 299x299 batch of 2 is then fed, and
    the 'Mixed_7c' end point must have shape [2, 8, 8, 2048].
    """
    tf.reset_default_graph()
    batch_size = 2
    height, width = 299, 299
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v3(inputs, num_classes)
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_7c']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])
开发者ID:tensorflow,项目名称:models,代码行数:19,代码来源:inception_v3_test.py

示例13: testGlobalPoolUnknownImageShape

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testGlobalPoolUnknownImageShape(self):
    """InceptionV3 with global pooling accepts inputs of unknown height/width.

    The static logits shape is checked. A 330x400 image is then fed, and the
    'Mixed_7c' end point must have shape [1, 8, 11, 2048].
    """
    tf.reset_default_graph()
    batch_size = 1
    height, width = 330, 400
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v3(inputs, num_classes,
                                                  global_pool=True)
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_7c']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 11, 2048])
开发者ID:tensorflow,项目名称:models,代码行数:20,代码来源:inception_v3_test.py

示例14: testGlobalPoolUnknownImageShape

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def testGlobalPoolUnknownImageShape(self):
    """InceptionV1 with global pooling accepts inputs of unknown height/width.

    The static logits shape is checked. A 250x300 image is then fed, and the
    'Mixed_5c' end point must have shape [1, 8, 10, 1024].
    """
    tf.reset_default_graph()
    batch_size = 1
    height, width = 250, 300
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as sess:
      inputs = tf.placeholder(
          tf.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception.inception_v1(inputs, num_classes,
                                                  global_pool=True)
      self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_5c']
      feed_dict = {inputs: input_np}
      tf.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 10, 1024])
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:inception_v1_test.py

示例15: _get_output_names

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import reset_default_graph [as 别名]
def _get_output_names(self):
        """Return the concatenated output names.

        Loads the SavedModel in ``self._model_dir`` with the tags from
        ``self._get_tag_set()``. It then collects the output tensor name of
        every signature, dropping a trailing ``:0`` output index.

        Returns:
            The unique output names joined with ",". The names are sorted so
            that the result does not depend on set iteration order, which
            varies between interpreter runs.

        Raises:
            ImportError: if tensorflow cannot be imported.
        """
        try:
            import tensorflow.compat.v1 as tf
        except ImportError as err:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.") from err
        tags = self._get_tag_set()
        output_names = set()
        try:
            with tf.Session() as sess:
                meta_graph_def = tf.saved_model.loader.load(sess,
                                                            tags,
                                                            self._model_dir)
                for sig_def in meta_graph_def.signature_def.values():
                    for output_tensor in sig_def.outputs.values():
                        name = output_tensor.name
                        # Strip only the ":0" output-index suffix.
                        if name.endswith(":0"):
                            name = name[:-2]
                        output_names.add(name)
        finally:
            # Clear the loaded graph even if loading the model fails.
            tf.reset_default_graph()
        return ",".join(sorted(output_names))
开发者ID:apache,项目名称:incubator-tvm,代码行数:21,代码来源:tensorflow_parser.py


注:本文中的tensorflow.compat.v1.reset_default_graph方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。