当前位置: 首页>>代码示例>>Python>>正文


Python inputs.create_train_input_fn方法代码示例

本文整理汇总了Python中object_detection.inputs.create_train_input_fn方法的典型用法代码示例。如果您正苦于以下问题:Python inputs.create_train_input_fn方法的具体用法?Python inputs.create_train_input_fn怎么用?Python inputs.create_train_input_fn使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在object_detection.inputs的用法示例。


在下文中一共展示了inputs.create_train_input_fn方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: setUp

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def setUp(self):
    """Create a saved checkpoint and a training input function for the tests.

    The model weight is filled with 42s before the checkpoint is written and
    then overwritten with ones, so the in-memory value differs from the saved
    one (presumably so a later restore is observable -- confirm against the
    tests that use `self._ckpt_path`).
    """
    super(CheckpointV2Test, self).setUp()

    model = SimpleModel()
    self._model = model
    tf.keras.backend.set_value(model.weight, np.ones(10) * 42)
    checkpoint = tf.train.Checkpoint(model=model)

    temp_dir = tf.test.get_temp_dir()
    self._test_dir = temp_dir
    self._ckpt_path = checkpoint.save(os.path.join(temp_dir, 'ckpt'))
    # Overwrite the weight after saving; the checkpoint keeps the 42s.
    tf.keras.backend.set_value(model.weight, np.ones(10))

    # Load the pipeline configs for the test model and apply the overrides.
    config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
    pipeline_configs = config_util.get_configs_from_pipeline_file(config_path)
    pipeline_configs = config_util.merge_external_params_with_configs(
        pipeline_configs, kwargs_dict=_get_config_kwarg_overrides())

    self._train_input_fn = inputs.create_train_input_fn(
        pipeline_configs['train_config'],
        pipeline_configs['train_input_config'],
        pipeline_configs['model'])
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:model_lib_tf2_test.py

示例2: test_faster_rcnn_resnet50_train_input

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_faster_rcnn_resnet50_train_input(self):
    """Check FasterRcnnResnet50 train input shapes and dtypes.

    Builds the train input function from the 'faster_rcnn_resnet50_pets'
    configs with 37 classes, draws one (features, labels) pair through an
    initializable iterator, and asserts each tensor's static shape and dtype.
    """
    pipeline_configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
    model_config = pipeline_configs['model']
    model_config.faster_rcnn.num_classes = 37
    input_fn = inputs.create_train_input_fn(
        pipeline_configs['train_config'],
        pipeline_configs['train_input_config'],
        model_config)
    iterator = _make_initializable_iterator(input_fn())
    features, labels = iterator.get_next()

    # Features: image and hash key.
    image = features[fields.InputDataFields.image]
    self.assertAllEqual([1, None, None, 3], image.shape.as_list())
    self.assertEqual(tf.float32, image.dtype)

    hash_key = features[inputs.HASH_KEY]
    self.assertAllEqual([1], hash_key.shape.as_list())
    self.assertEqual(tf.int32, hash_key.dtype)

    # Labels: groundtruth boxes, classes, confidences and weights.
    boxes = labels[fields.InputDataFields.groundtruth_boxes]
    self.assertAllEqual([1, 100, 4], boxes.shape.as_list())
    self.assertEqual(tf.float32, boxes.dtype)

    classes = labels[fields.InputDataFields.groundtruth_classes]
    self.assertAllEqual(
        [1, 100, model_config.faster_rcnn.num_classes],
        classes.shape.as_list())
    self.assertEqual(tf.float32, classes.dtype)

    confidences = labels[fields.InputDataFields.groundtruth_confidences]
    self.assertAllEqual(
        [1, 100, model_config.faster_rcnn.num_classes],
        confidences.shape.as_list())
    self.assertEqual(tf.float32, confidences.dtype)

    weights = labels[fields.InputDataFields.groundtruth_weights]
    self.assertAllEqual([1, 100], weights.shape.as_list())
    self.assertEqual(tf.float32, weights.dtype)
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:38,代码来源:inputs_test.py

示例3: test_error_with_bad_train_config

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_error_with_bad_train_config(self):
    """Expect a TypeError when the train config has the wrong type."""
    pipeline_configs = _get_configs_for_model('ssd_inception_v2_pets')
    pipeline_configs['model'].ssd.num_classes = 37
    input_fn = inputs.create_train_input_fn(
        # The eval config is passed on purpose; a `TrainConfig` is expected.
        train_config=pipeline_configs['eval_config'],
        train_input_config=pipeline_configs['train_input_config'],
        model_config=pipeline_configs['model'])
    # Only the call to the input function is covered by the assertion, not
    # its creation above.
    self.assertRaises(TypeError, input_fn)
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:12,代码来源:inputs_test.py

示例4: test_error_with_bad_train_input_config

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_error_with_bad_train_input_config(self):
    """Expect a TypeError when the train input config has the wrong type."""
    pipeline_configs = _get_configs_for_model('ssd_inception_v2_pets')
    pipeline_configs['model'].ssd.num_classes = 37
    input_fn = inputs.create_train_input_fn(
        train_config=pipeline_configs['train_config'],
        # The model config is passed on purpose; an `InputReader` is expected.
        train_input_config=pipeline_configs['model'],
        model_config=pipeline_configs['model'])
    # Only the call to the input function is covered by the assertion, not
    # its creation above.
    self.assertRaises(TypeError, input_fn)
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:12,代码来源:inputs_test.py

示例5: test_error_with_bad_train_model_config

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_error_with_bad_train_model_config(self):
    """Expect a TypeError when the model config has the wrong type."""
    pipeline_configs = _get_configs_for_model('ssd_inception_v2_pets')
    pipeline_configs['model'].ssd.num_classes = 37
    input_fn = inputs.create_train_input_fn(
        train_config=pipeline_configs['train_config'],
        train_input_config=pipeline_configs['train_input_config'],
        # The train config is passed on purpose; a `DetectionModel` is
        # expected.
        model_config=pipeline_configs['train_config'])
    # Only the call to the input function is covered by the assertion, not
    # its creation above.
    self.assertRaises(TypeError, input_fn)
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:12,代码来源:inputs_test.py

示例6: test_faster_rcnn_resnet50_train_input

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_faster_rcnn_resnet50_train_input(self):
    """Check FasterRcnnResnet50 train input shapes and dtypes.

    Uses the 'faster_rcnn_resnet50_pets' configs with
    `unpad_groundtruth_tensors` enabled and 37 classes, calls the train input
    function directly, and asserts each returned tensor's static shape and
    dtype.
    """
    pipeline_configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
    pipeline_configs['train_config'].unpad_groundtruth_tensors = True
    model_config = pipeline_configs['model']
    model_config.faster_rcnn.num_classes = 37
    input_fn = inputs.create_train_input_fn(
        pipeline_configs['train_config'],
        pipeline_configs['train_input_config'],
        model_config)
    features, labels = input_fn()

    # Features: image and hash key.
    image = features[fields.InputDataFields.image]
    self.assertAllEqual([1, None, None, 3], image.shape.as_list())
    self.assertEqual(tf.float32, image.dtype)

    hash_key = features[inputs.HASH_KEY]
    self.assertAllEqual([1], hash_key.shape.as_list())
    self.assertEqual(tf.int32, hash_key.dtype)

    # Labels: groundtruth boxes, classes and weights.
    boxes = labels[fields.InputDataFields.groundtruth_boxes]
    self.assertAllEqual([1, 50, 4], boxes.shape.as_list())
    self.assertEqual(tf.float32, boxes.dtype)

    classes = labels[fields.InputDataFields.groundtruth_classes]
    self.assertAllEqual(
        [1, 50, model_config.faster_rcnn.num_classes],
        classes.shape.as_list())
    self.assertEqual(tf.float32, classes.dtype)

    weights = labels[fields.InputDataFields.groundtruth_weights]
    self.assertAllEqual([1, 50], weights.shape.as_list())
    self.assertEqual(tf.float32, weights.dtype)
开发者ID:cagbal,项目名称:ros_people_object_detection_tensorflow,代码行数:33,代码来源:inputs_test.py

示例7: test_ssd_inceptionV2_train_input

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_ssd_inceptionV2_train_input(self):
    """Check SSDInceptionV2 train input shapes and dtypes.

    Uses the 'ssd_inception_v2_pets' configs with 37 classes, calls the train
    input function directly, and asserts each returned tensor's static shape
    (led by the configured batch size) and dtype.
    """
    pipeline_configs = _get_configs_for_model('ssd_inception_v2_pets')
    model_config = pipeline_configs['model']
    model_config.ssd.num_classes = 37
    batch_size = pipeline_configs['train_config'].batch_size
    input_fn = inputs.create_train_input_fn(
        pipeline_configs['train_config'],
        pipeline_configs['train_input_config'],
        model_config)
    features, labels = input_fn()

    # Features: image and hash key.
    image = features[fields.InputDataFields.image]
    self.assertAllEqual([batch_size, 300, 300, 3], image.shape.as_list())
    self.assertEqual(tf.float32, image.dtype)

    hash_key = features[inputs.HASH_KEY]
    self.assertAllEqual([batch_size], hash_key.shape.as_list())
    self.assertEqual(tf.int32, hash_key.dtype)

    # Labels: box counts, groundtruth boxes, classes and weights.
    num_boxes = labels[fields.InputDataFields.num_groundtruth_boxes]
    self.assertAllEqual([batch_size], num_boxes.shape.as_list())
    self.assertEqual(tf.int32, num_boxes.dtype)

    boxes = labels[fields.InputDataFields.groundtruth_boxes]
    self.assertAllEqual([batch_size, 50, 4], boxes.shape.as_list())
    self.assertEqual(tf.float32, boxes.dtype)

    classes = labels[fields.InputDataFields.groundtruth_classes]
    self.assertAllEqual(
        [batch_size, 50, model_config.ssd.num_classes],
        classes.shape.as_list())
    self.assertEqual(tf.float32, classes.dtype)

    weights = labels[fields.InputDataFields.groundtruth_weights]
    self.assertAllEqual([batch_size, 50], weights.shape.as_list())
    self.assertEqual(tf.float32, weights.dtype)
开发者ID:cagbal,项目名称:ros_people_object_detection_tensorflow,代码行数:38,代码来源:inputs_test.py

示例8: test_faster_rcnn_resnet50_train_input

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_faster_rcnn_resnet50_train_input(self):
    """Check FasterRcnnResnet50 train input shapes and dtypes.

    Uses the 'faster_rcnn_resnet50_pets' configs with
    `unpad_groundtruth_tensors` enabled and 37 classes, calls the train input
    function directly, and asserts each returned tensor's static shape and
    dtype. The expected shapes here carry no leading batch dimension.
    """
    pipeline_configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
    pipeline_configs['train_config'].unpad_groundtruth_tensors = True
    model_config = pipeline_configs['model']
    model_config.faster_rcnn.num_classes = 37
    input_fn = inputs.create_train_input_fn(
        pipeline_configs['train_config'],
        pipeline_configs['train_input_config'],
        model_config)
    features, labels = input_fn()

    # Features: image and hash key.
    image = features[fields.InputDataFields.image]
    self.assertAllEqual([None, None, 3], image.shape.as_list())
    self.assertEqual(tf.float32, image.dtype)

    hash_key = features[inputs.HASH_KEY]
    self.assertAllEqual([], hash_key.shape.as_list())
    self.assertEqual(tf.int32, hash_key.dtype)

    # Labels: groundtruth boxes, classes and weights.
    boxes = labels[fields.InputDataFields.groundtruth_boxes]
    self.assertAllEqual([None, 4], boxes.shape.as_list())
    self.assertEqual(tf.float32, boxes.dtype)

    classes = labels[fields.InputDataFields.groundtruth_classes]
    self.assertAllEqual(
        [None, model_config.faster_rcnn.num_classes],
        classes.shape.as_list())
    self.assertEqual(tf.float32, classes.dtype)

    weights = labels[fields.InputDataFields.groundtruth_weights]
    self.assertAllEqual([None], weights.shape.as_list())
    self.assertEqual(tf.float32, weights.dtype)
开发者ID:ShreyAmbesh,项目名称:Traffic-Rule-Violation-Detection-System,代码行数:33,代码来源:inputs_test.py

示例9: test_faster_rcnn_resnet50_train_input

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_faster_rcnn_resnet50_train_input(self):
    """Check FasterRcnnResnet50 train input shapes and dtypes.

    Builds the train input function from the 'faster_rcnn_resnet50_pets'
    configs with 37 classes, draws one (features, labels) pair through an
    initializable iterator, and asserts each tensor's static shape and dtype.
    """
    pipeline_configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
    model_config = pipeline_configs['model']
    model_config.faster_rcnn.num_classes = 37
    input_fn = inputs.create_train_input_fn(
        pipeline_configs['train_config'],
        pipeline_configs['train_input_config'],
        model_config)
    iterator = _make_initializable_iterator(input_fn())
    features, labels = iterator.get_next()

    # Features: image and hash key.
    image = features[fields.InputDataFields.image]
    self.assertAllEqual([1, None, None, 3], image.shape.as_list())
    self.assertEqual(tf.float32, image.dtype)

    hash_key = features[inputs.HASH_KEY]
    self.assertAllEqual([1], hash_key.shape.as_list())
    self.assertEqual(tf.int32, hash_key.dtype)

    # Labels: groundtruth boxes, classes and weights.
    boxes = labels[fields.InputDataFields.groundtruth_boxes]
    self.assertAllEqual([1, 100, 4], boxes.shape.as_list())
    self.assertEqual(tf.float32, boxes.dtype)

    classes = labels[fields.InputDataFields.groundtruth_classes]
    self.assertAllEqual(
        [1, 100, model_config.faster_rcnn.num_classes],
        classes.shape.as_list())
    self.assertEqual(tf.float32, classes.dtype)

    weights = labels[fields.InputDataFields.groundtruth_weights]
    self.assertAllEqual([1, 100], weights.shape.as_list())
    self.assertEqual(tf.float32, weights.dtype)
开发者ID:BMW-InnovationLab,项目名称:BMW-TensorFlow-Training-GUI,代码行数:32,代码来源:inputs_test.py

示例10: test_ssd_inceptionV2_train_input

# 需要导入模块: from object_detection import inputs [as 别名]
# 或者: from object_detection.inputs import create_train_input_fn [as 别名]
def test_ssd_inceptionV2_train_input(self):
    """Check SSDInceptionV2 train input shapes and dtypes.

    Builds the train input function from the 'ssd_inception_v2_pets' configs
    with 37 classes, draws one (features, labels) pair through an
    initializable iterator, and asserts each tensor's static shape (led by
    the configured batch size) and dtype.
    """
    pipeline_configs = _get_configs_for_model('ssd_inception_v2_pets')
    model_config = pipeline_configs['model']
    model_config.ssd.num_classes = 37
    batch_size = pipeline_configs['train_config'].batch_size
    input_fn = inputs.create_train_input_fn(
        pipeline_configs['train_config'],
        pipeline_configs['train_input_config'],
        model_config)
    iterator = _make_initializable_iterator(input_fn())
    features, labels = iterator.get_next()

    # Features: image and hash key.
    image = features[fields.InputDataFields.image]
    self.assertAllEqual([batch_size, 300, 300, 3], image.shape.as_list())
    self.assertEqual(tf.float32, image.dtype)

    hash_key = features[inputs.HASH_KEY]
    self.assertAllEqual([batch_size], hash_key.shape.as_list())
    self.assertEqual(tf.int32, hash_key.dtype)

    # Labels: box counts, groundtruth boxes, classes and weights.
    num_boxes = labels[fields.InputDataFields.num_groundtruth_boxes]
    self.assertAllEqual([batch_size], num_boxes.shape.as_list())
    self.assertEqual(tf.int32, num_boxes.dtype)

    boxes = labels[fields.InputDataFields.groundtruth_boxes]
    self.assertAllEqual([batch_size, 100, 4], boxes.shape.as_list())
    self.assertEqual(tf.float32, boxes.dtype)

    classes = labels[fields.InputDataFields.groundtruth_classes]
    self.assertAllEqual(
        [batch_size, 100, model_config.ssd.num_classes],
        classes.shape.as_list())
    self.assertEqual(tf.float32, classes.dtype)

    weights = labels[fields.InputDataFields.groundtruth_weights]
    self.assertAllEqual([batch_size, 100], weights.shape.as_list())
    self.assertEqual(tf.float32, weights.dtype)
开发者ID:BMW-InnovationLab,项目名称:BMW-TensorFlow-Training-GUI,代码行数:38,代码来源:inputs_test.py


注:本文中的object_detection.inputs.create_train_input_fn方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。