当前位置: 首页>>代码示例>>Python>>正文


Python nets_factory.get_network_fn方法代码示例

本文整理汇总了Python中nets.nets_factory.get_network_fn方法的典型用法代码示例。如果您正苦于以下问题：Python nets_factory.get_network_fn方法的具体用法？Python nets_factory.get_network_fn怎么用？Python nets_factory.get_network_fn使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块nets.nets_factory的用法示例。


在下文中一共展示了nets_factory.get_network_fn方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def main(_):
  """Serialize an inference GraphDef for FLAGS.model_name to FLAGS.output_file.

  The classifier is built on a single-image float32 placeholder named 'input'
  of shape [1, S, S, 3]. S is the network's `default_image_size` when the
  factory attached one, otherwise FLAGS.default_image_size. The number of
  classes is the 'train' split's num_classes minus FLAGS.labels_offset.

  Raises:
    ValueError: if --output_file was not supplied.
  """
  destination = FLAGS.output_file
  if not destination:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as export_graph:
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    class_count = train_split.num_classes - FLAGS.labels_offset
    model_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=class_count,
        is_training=FLAGS.is_training)

    # Prefer the resolution the factory attached to the network; only fall
    # back to the flag when that attribute is absent.
    try:
      side = model_fn.default_image_size
    except AttributeError:
      side = FLAGS.default_image_size

    input_shape = [1, side, side, 3]
    model_fn(tf.placeholder(name='input', dtype=tf.float32, shape=input_shape))

    graph_def = export_graph.as_graph_def()
    with gfile.GFile(destination, 'wb') as out_file:
      out_file.write(graph_def.SerializeToString())
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:23,代码来源:export_inference_graph.py

示例2: main

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def main(_):
  """Write a serialized inference GraphDef for FLAGS.model_name.

  The input placeholder is named 'input' and has shape
  [FLAGS.batch_size, S, S, 3]. S is FLAGS.image_size when that flag is
  truthy, otherwise the network's `default_image_size`.

  Raises:
    ValueError: if --output_file is empty or unset.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as export_graph:
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    model_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=train_split.num_classes - FLAGS.labels_offset,
        is_training=FLAGS.is_training)

    # A falsy --image_size (e.g. 0 or None) defers to the network's own size.
    side = FLAGS.image_size if FLAGS.image_size else model_fn.default_image_size
    images = tf.placeholder(
        name='input', dtype=tf.float32,
        shape=[FLAGS.batch_size, side, side, 3])
    model_fn(images)

    graph_def = export_graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as out_file:
      out_file.write(graph_def.SerializeToString())
开发者ID:yuantailing,项目名称:ctw-baseline,代码行数:21,代码来源:export_inference_graph.py

示例3: main

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def main(_):
  """Write a serialized inference GraphDef, optionally quantization-rewritten.

  The network is built on a float32 placeholder named 'input' with shape
  [FLAGS.batch_size, S, S, 3]. S is FLAGS.image_size when truthy, otherwise
  the network's `default_image_size`. When FLAGS.quantize is set,
  `tf.contrib.quantize.create_eval_graph()` is applied before serialization.

  Raises:
    ValueError: if --output_file is empty or unset.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as export_graph:
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    model_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=train_split.num_classes - FLAGS.labels_offset,
        is_training=FLAGS.is_training)

    side = FLAGS.image_size if FLAGS.image_size else model_fn.default_image_size
    images = tf.placeholder(
        name='input', dtype=tf.float32,
        shape=[FLAGS.batch_size, side, side, 3])
    model_fn(images)

    # The rewrite must run after the model exists in the graph and before the
    # graph is serialized.
    if FLAGS.quantize:
      tf.contrib.quantize.create_eval_graph()

    graph_def = export_graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as out_file:
      out_file.write(graph_def.SerializeToString())
开发者ID:andrewekhalel,项目名称:edafa,代码行数:25,代码来源:export_inference_graph.py

示例4: testGetNetworkFnFirstHalf

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFnFirstHalf(self):
    """Smoke-test the first ten entries of nets_factory.networks_map.

    Every network function is created in its own graph. Except for 'i3d' and
    's3dg' (which are created but not called), each is run on a random
    [batch, side, side, 3] batch. It must return a tf.Tensor of logits whose
    first/last dimensions are batch/classes, plus a dict of end points.
    """
    batch, classes = 5, 1000
    not_called = ('i3d', 's3dg')
    for name in list(nets_factory.networks_map.keys())[:10]:
      with tf.Graph().as_default() as graph, self.test_session(graph):
        fn = nets_factory.get_network_fn(name, num_classes=classes)
        # 224 is the fallback side when the factory supplies no default.
        side = getattr(fn, 'default_image_size', 224)
        if name in not_called:
          continue
        logits, end_points = fn(tf.random_uniform((batch, side, side, 3)))
        self.assertTrue(isinstance(logits, tf.Tensor))
        self.assertTrue(isinstance(end_points, dict))
        dims = logits.get_shape().as_list()
        self.assertEqual(dims[0], batch)
        self.assertEqual(dims[-1], classes)
开发者ID:google-research,项目名称:morph-net,代码行数:18,代码来源:nets_factory_test.py

示例5: testGetNetworkFnSecondHalf

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFnSecondHalf(self):
    """Smoke-test entries 10 onward of nets_factory.networks_map.

    Every network function is created in its own graph. Except for 'i3d' and
    's3dg' (which are created but not called), each is run on a random
    [batch, side, side, 3] batch. It must return a tf.Tensor of logits whose
    first/last dimensions are batch/classes, plus a dict of end points.
    """
    batch, classes = 5, 1000
    not_called = ('i3d', 's3dg')
    for name in list(nets_factory.networks_map.keys())[10:]:
      with tf.Graph().as_default() as graph, self.test_session(graph):
        fn = nets_factory.get_network_fn(name, num_classes=classes)
        # 224 is the fallback side when the factory supplies no default.
        side = getattr(fn, 'default_image_size', 224)
        if name in not_called:
          continue
        logits, end_points = fn(tf.random_uniform((batch, side, side, 3)))
        self.assertTrue(isinstance(logits, tf.Tensor))
        self.assertTrue(isinstance(end_points, dict))
        dims = logits.get_shape().as_list()
        self.assertEqual(dims[0], batch)
        self.assertEqual(dims[-1], classes)
开发者ID:google-research,项目名称:morph-net,代码行数:18,代码来源:nets_factory_test.py

示例6: main

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def main(_):
  """Serialize an inference GraphDef for FLAGS.model_name to FLAGS.output_file.

  The classifier is built on a single-image float32 placeholder named 'input'
  of shape [1, S, S, 3]. S is the network's `default_image_size` when the
  factory attached one, otherwise FLAGS.default_image_size. The number of
  classes comes from the 'validation' split minus FLAGS.labels_offset.

  Raises:
    ValueError: if --output_file was not supplied.
  """
  destination = FLAGS.output_file
  if not destination:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as export_graph:
    validation_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'validation', FLAGS.dataset_dir)
    class_count = validation_split.num_classes - FLAGS.labels_offset
    model_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=class_count,
        is_training=FLAGS.is_training)

    # Prefer the resolution the factory attached to the network; only fall
    # back to the flag when that attribute is absent.
    try:
      side = model_fn.default_image_size
    except AttributeError:
      side = FLAGS.default_image_size

    input_shape = [1, side, side, 3]
    model_fn(tf.placeholder(name='input', dtype=tf.float32, shape=input_shape))

    graph_def = export_graph.as_graph_def()
    with gfile.GFile(destination, 'wb') as out_file:
      out_file.write(graph_def.SerializeToString())
开发者ID:anthonyhu,项目名称:tumblr-emotions,代码行数:23,代码来源:export_inference_graph.py

示例7: _build_evaluate_model

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def _build_evaluate_model(self):
        """Build the evaluation graph for a single content/style image pair.

        Creates two float32 placeholders of shape [None, None, 3]
        (`self.input_image` and `self.style_image`) and preprocesses both with
        the eval-mode (is_training=False) preprocessing for
        `self.config.net_name`. The batched results go through
        `self._swap_net` and then `self._inverse_net`. Sets
        `self.swaped_tensor`, `self.generated`, `self.evaluate_op` (the
        generated batch squeezed on axis 0), `self.init_op`, and
        `self.save_variables` (trainable variables named "inverse_net...").
        """
        self.input_image = tf.placeholder(tf.float32, shape=[None, None, 3])
        self.style_image = tf.placeholder(tf.float32, shape=[None, None, 3])
        preprocess = preprocessing_factory.get_preprocessing(
            self.config.net_name, is_training=False)

        # A zero/None evaluate size falls back to PREPROCESS_SIZE.
        target_h = self.evaluate_height or self.PREPROCESS_SIZE
        target_w = self.evaluate_width or self.PREPROCESS_SIZE

        content = preprocess(self.input_image, target_h, target_w,
                             resize_side_min=min(target_h, target_w))
        content_batch = tf.expand_dims(content, axis=0)

        # The style image is always preprocessed at PREPROCESS_SIZE.
        style = preprocess(self.style_image, self.PREPROCESS_SIZE,
                           self.PREPROCESS_SIZE)
        style_batch = tf.expand_dims(style, axis=0)

        self.swaped_tensor = self._swap_net(content_batch, style_batch)
        self.generated = self._inverse_net(self.swaped_tensor)

        self.evaluate_op = tf.squeeze(self.generated, axis=0)
        self.init_op = self._get_network_init_fn()
        self.save_variables = [
            v for v in tf.trainable_variables()
            if v.name.startswith("inverse_net")
        ]
开发者ID:benbenlijie,项目名称:style_swap_tensorflow,代码行数:27,代码来源:style_swap_model.py

示例8: _train_inverse

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def _train_inverse(self, generated, swaped_tensor):
        """Attach the inverse-network loss, summaries and Adam train op.

        Re-runs the backbone for `self.config.net_name` (with variable reuse)
        on `generated` and selects the single end point whose name contains
        `self.style_layer`. An L2 loss between that end point and
        `swaped_tensor` is added to tf.losses. Only trainable variables whose
        names start with "inverse_net" are optimized and recorded in
        `self.save_variables`.

        Args:
            generated: batched generated images; each element along axis 0 is
                preprocessed individually before the backbone runs.
            swaped_tensor: target feature tensor the re-encoded images are
                compared against.

        Raises:
            ValueError: if no end point, or more than one, has a name
                containing `self.style_layer`.
        """
        preprocess_fn = preprocessing_factory.get_preprocessing(self.config.net_name, is_training=False)
        network_fn = nets_factory.get_network_fn(self.config.net_name, num_classes=1, is_training=False)
        # reuse=True: the backbone variables are expected to already exist
        # (presumably created when the swap features were built — confirm).
        with tf.variable_scope("", reuse=True):
            # preprocess_fn is applied image by image (presumably it expects
            # unbatched input), then the results are re-stacked into a batch.
            preprocessed_image = tf.stack([preprocess_fn(img, self.PREPROCESS_SIZE, self.PREPROCESS_SIZE)
                                           for img in tf.unstack(generated, axis=0)])
            _, inversed_endpoints_dict = network_fn(preprocessed_image, spatial_squeeze=False)
            matches = [name for name in inversed_endpoints_dict if self.style_layer in name]
            # Replaces the opaque "not enough/too many values to unpack" error
            # with one that says which layer was requested and what matched.
            if len(matches) != 1:
                raise ValueError(
                    "Expected exactly one end point containing %r, found %d: %s"
                    % (self.style_layer, len(matches), sorted(matches)))
            inversed_style_layer = inversed_endpoints_dict[matches[0]]
        tf.losses.add_loss(tf.nn.l2_loss(swaped_tensor - inversed_style_layer))
        self.loss_op = tf.losses.get_total_loss()

        train_vars = [var for var in tf.trainable_variables() if var.name.startswith("inverse_net")]
        slim.summarize_tensor(self.loss_op, "loss")
        slim.summarize_tensors(train_vars)
        self.save_variables = train_vars

        # Decay the learning rate by a factor of 0.66 every 1000 global steps.
        learning_rate = tf.train.exponential_decay(self.config.learning_rate, self.global_step, 1000, 0.66,
                                                   name="learning_rate")
        self.train_op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss_op, self.global_step, train_vars)
开发者ID:benbenlijie,项目名称:style_swap_tensorflow,代码行数:25,代码来源:style_swap_model.py

示例9: prepare_inception_score_classifier

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def prepare_inception_score_classifier(classifier_name, num_classes, images, return_saver=True):
    """Build a non-training classifier over `images` for inception scoring.

    The network is created with weight_decay=0.0 and is_training=False. It is
    first called with prediction_fn=tf.sigmoid and create_aux_logits=False.
    If that call raises TypeError, a warning is logged and it is called with
    `images` alone.

    Returns:
      (predictions, end_points, saver) when return_saver is true, otherwise
      (predictions, end_points). predictions is end_points['Predictions'], and
      saver is a tf.train.Saver over the model variables in
      nets_factory.scopes_map[classifier_name].
    """
    classifier_fn = nets_factory.get_network_fn(
        classifier_name,
        num_classes=num_classes,
        weight_decay=0.0,
        is_training=False,
    )
    # Note: you may need to change the prediction_fn here.
    try:
        _, end_points = classifier_fn(images, prediction_fn=tf.sigmoid, create_aux_logits=False)
    except TypeError:
        # Fall back to a plain call when the network rejects those kwargs.
        tf.logging.warning('Cannot specify prediction_fn=tf.sigmoid, create_aux_logits=False.')
        _, end_points = classifier_fn(images)

    # Collected even when no saver is requested, matching the original order.
    restore_vars = slim.get_model_variables(scope=nets_factory.scopes_map[classifier_name])
    predictions = end_points['Predictions']
    if not return_saver:
        return predictions, end_points
    return predictions, end_points, tf.train.Saver(restore_vars)
开发者ID:jerryli27,项目名称:TwinGAN,代码行数:23,代码来源:image_generation.py

示例10: testGetNetworkFn

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFn(self):
    """Build every registered network and check its logits and end points.

    Each network runs on a random [batch_size, image_size, image_size, 3]
    batch. The logits must be a tf.Tensor whose first and last dimensions are
    batch_size and num_classes, and the end points must be a dict.
    """
    batch_size = 5
    num_classes = 1000
    for net in nets_factory.networks_map:
      # Give every network a fresh graph. Building all models into one shared
      # graph lets variables from one model clash with another's (e.g. two
      # entries using the same variable scope), and the graph keeps growing.
      with tf.Graph().as_default() as g, self.test_session(g):
        net_fn = nets_factory.get_network_fn(net, num_classes)
        # Most networks use 224 as their default_image_size
        image_size = getattr(net_fn, 'default_image_size', 224)
        inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
        logits, end_points = net_fn(inputs)
        self.assertTrue(isinstance(logits, tf.Tensor))
        self.assertTrue(isinstance(end_points, dict))
        self.assertEqual(logits.get_shape().as_list()[0], batch_size)
        self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:16,代码来源:nets_factory_test.py

示例11: testGetNetworkFnArgScope

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFnArgScope(self):
    """Variables created under a device arg_scope must honor that device.

    Builds 'cifarnet' inside an arg_scope that places slim model variables
    and variables on '/CPU:0'. It then checks the device of the first global
    variable in the 'CifarNet/conv1' scope.
    """
    net = 'cifarnet'
    batch_size, num_classes = 5, 10
    with self.test_session(use_gpu=True):
      net_fn = nets_factory.get_network_fn(net, num_classes)
      side = getattr(net_fn, 'default_image_size', 224)
      cpu_scope = slim.arg_scope([slim.model_variable, slim.variable],
                                 device='/CPU:0')
      with cpu_scope:
        net_fn(tf.random_uniform((batch_size, side, side, 3)))
      conv1_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     'CifarNet/conv1')
      self.assertDeviceEqual('/CPU:0', conv1_vars[0].device)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:15,代码来源:nets_factory_test.py

示例12: testGetNetworkFnFirstHalf

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFnFirstHalf(self):
    """Build the first ten entries of networks_map, each in its own graph.

    Each network runs on a random [batch, side, side, 3] batch. It must
    produce a tf.Tensor of logits with first/last dimensions batch/classes,
    plus a dict of end points.
    """
    batch, classes = 5, 1000
    first_ten = list(nets_factory.networks_map.keys())[:10]
    for name in first_ten:
      with tf.Graph().as_default() as graph, self.test_session(graph):
        fn = nets_factory.get_network_fn(name, classes)
        # 224 is the fallback side when the factory supplies no default.
        side = getattr(fn, 'default_image_size', 224)
        logits, end_points = fn(tf.random_uniform((batch, side, side, 3)))
        self.assertTrue(isinstance(logits, tf.Tensor))
        self.assertTrue(isinstance(end_points, dict))
        dims = logits.get_shape().as_list()
        self.assertEqual(dims[0], batch)
        self.assertEqual(dims[-1], classes)
开发者ID:leimao,项目名称:DeepLab_v3,代码行数:16,代码来源:nets_factory_test.py

示例13: testGetNetworkFnSecondHalf

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFnSecondHalf(self):
    """Build entries 10 onward of networks_map, each in its own graph.

    Each network runs on a random [batch, side, side, 3] batch. It must
    produce a tf.Tensor of logits with first/last dimensions batch/classes,
    plus a dict of end points.
    """
    batch, classes = 5, 1000
    remaining = list(nets_factory.networks_map.keys())[10:]
    for name in remaining:
      with tf.Graph().as_default() as graph, self.test_session(graph):
        fn = nets_factory.get_network_fn(name, classes)
        # 224 is the fallback side when the factory supplies no default.
        side = getattr(fn, 'default_image_size', 224)
        logits, end_points = fn(tf.random_uniform((batch, side, side, 3)))
        self.assertTrue(isinstance(logits, tf.Tensor))
        self.assertTrue(isinstance(end_points, dict))
        dims = logits.get_shape().as_list()
        self.assertEqual(dims[0], batch)
        self.assertEqual(dims[-1], classes)
开发者ID:leimao,项目名称:DeepLab_v3,代码行数:16,代码来源:nets_factory_test.py

示例14: testGetNetworkFnFirstHalf

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFnFirstHalf(self):
    """Build the first ten registered networks and check their outputs.

    Each network is built in its own graph on a random
    [batch_size, image_size, image_size, 3] batch. The logits must be a
    tf.Tensor whose first and last dimensions are batch_size and num_classes,
    and the end points must be a dict.
    """
    batch_size = 5
    num_classes = 1000
    # dict.keys() returns a view on Python 3, and views cannot be sliced;
    # materialize it as a list first.
    for net in list(nets_factory.networks_map.keys())[:10]:
      with tf.Graph().as_default() as g, self.test_session(g):
        net_fn = nets_factory.get_network_fn(net, num_classes)
        # Most networks use 224 as their default_image_size
        image_size = getattr(net_fn, 'default_image_size', 224)
        inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
        logits, end_points = net_fn(inputs)
        self.assertTrue(isinstance(logits, tf.Tensor))
        self.assertTrue(isinstance(end_points, dict))
        self.assertEqual(logits.get_shape().as_list()[0], batch_size)
        self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
开发者ID:yuantailing,项目名称:ctw-baseline,代码行数:16,代码来源:nets_factory_test.py

示例15: testGetNetworkFnSecondHalf

# 需要导入模块: from nets import nets_factory [as 别名]
# 或者: from nets.nets_factory import get_network_fn [as 别名]
def testGetNetworkFnSecondHalf(self):
    """Build registered networks from index 10 onward and check their outputs.

    Each network is built in its own graph on a random
    [batch_size, image_size, image_size, 3] batch. The logits must be a
    tf.Tensor whose first and last dimensions are batch_size and num_classes,
    and the end points must be a dict.
    """
    batch_size = 5
    num_classes = 1000
    # dict.keys() returns a view on Python 3, and views cannot be sliced;
    # materialize it as a list first.
    for net in list(nets_factory.networks_map.keys())[10:]:
      with tf.Graph().as_default() as g, self.test_session(g):
        net_fn = nets_factory.get_network_fn(net, num_classes)
        # Most networks use 224 as their default_image_size
        image_size = getattr(net_fn, 'default_image_size', 224)
        inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
        logits, end_points = net_fn(inputs)
        self.assertTrue(isinstance(logits, tf.Tensor))
        self.assertTrue(isinstance(end_points, dict))
        self.assertEqual(logits.get_shape().as_list()[0], batch_size)
        self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
开发者ID:yuantailing,项目名称:ctw-baseline,代码行数:16,代码来源:nets_factory_test.py


注:本文中的nets.nets_factory.get_network_fn方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。