

Python preprocessing.make_tf_example Method Code Examples

This article collects typical usage examples of the Python method preprocessing.make_tf_example. If you are wondering how exactly preprocessing.make_tf_example is used, how to call it, or what real-world examples look like, the curated code samples below may help. You can also explore further usage examples from the preprocessing module it belongs to.


Nine code examples of the preprocessing.make_tf_example method are shown below, sorted by popularity by default.
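For orientation before the examples: as used below, make_tf_example packs one training position (a board-feature array, a policy vector pi, and a scalar outcome value) into a tf.train.Example that write_tf_examples then serializes to a TFRecord file. The following is only a minimal sketch of that idea; the feature key names ('x', 'pi', 'outcome'), shapes, and encodings are assumptions for illustration, not taken from the project's preprocessing.py.

import numpy as np
import tensorflow as tf


def make_tf_example_sketch(features, pi, value):
    # Hedged sketch, not the project's implementation: store the feature
    # planes and policy vector as raw bytes and the game outcome as a float.
    return tf.train.Example(features=tf.train.Features(feature={
        'x': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[features.tobytes()])),
        'pi': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[pi.tobytes()])),
        'outcome': tf.train.Feature(
            float_list=tf.train.FloatList(value=[value])),
    }))


# Usage with placeholder shapes (a 19x19 board with 17 feature planes and a
# 362-way policy, as in Minigo-style pipelines):
example = make_tf_example_sketch(
    np.zeros([19, 19, 17], dtype=np.uint8),
    np.zeros([19 * 19 + 1], dtype=np.float32),
    0.5)
serialized = example.SerializeToString()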

Example 1: test_rotate_pyfunc

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_rotate_pyfunc(self):
        num_records = 20
        raw_data = self.create_random_data(num_records)
        tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

        with tempfile.NamedTemporaryFile() as f:
            preprocessing.write_tf_examples(f.name, tfexamples)

            self.reset_random()
            run_one = self.extract_data(f.name, random_rotation=False)

            self.reset_random()
            run_two = self.extract_data(f.name, random_rotation=True)

            self.reset_random()
            run_three = self.extract_data(f.name, random_rotation=True)

        self.assert_rotate_data(run_one, run_two, run_three) 
Developer: mlperf, Project: training, Lines of code: 20, Source file: test_preprocessing.py
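A side note on the list(map(preprocessing.make_tf_example, *zip(*raw_data))) idiom that recurs in these tests: zip(*raw_data) transposes the list of (features, pi, value) tuples into three parallel sequences, which map() then feeds to make_tf_example positionally. A small standalone illustration with placeholder values:

raw_data = [('x0', 'pi0', 0.5), ('x1', 'pi1', -0.5)]
features, pis, values = zip(*raw_data)
assert features == ('x0', 'x1')
assert pis == ('pi0', 'pi1')
assert values == (0.5, -0.5)
# So map(make_tf_example, *zip(*raw_data)) is equivalent to
# [make_tf_example(x, pi, val) for (x, pi, val) in raw_data].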

Example 2: test_tpu_rotate

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_tpu_rotate(self):
        num_records = 100
        raw_data = self.create_random_data(num_records)
        tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

        with tempfile.NamedTemporaryFile() as f:
            preprocessing.write_tf_examples(f.name, tfexamples)

            self.reset_random()
            run_one = self.extract_tpu_data(f.name, random_rotation=False)

            self.reset_random()
            run_two = self.extract_tpu_data(f.name, random_rotation=True)

            self.reset_random()
            run_three = self.extract_tpu_data(f.name, random_rotation=True)

        self.assert_rotate_data(run_one, run_two, run_three) 
Developer: mlperf, Project: training, Lines of code: 20, Source file: test_preprocessing.py

Example 3: test_serialize_round_trip

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_serialize_round_trip(self):
        np.random.seed(1)
        raw_data = self.create_random_data(10)
        tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

        with tempfile.NamedTemporaryFile() as f:
            preprocessing.write_tf_examples(f.name, tfexamples)
            recovered_data = self.extract_data(f.name)

        self.assertEqualData(raw_data, recovered_data) 
Developer: mlperf, Project: training_results_v0.5, Lines of code: 12, Source file: test_preprocessing.py

Example 4: test_filter

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_filter(self):
        raw_data = self.create_random_data(100)
        tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

        with tempfile.NamedTemporaryFile() as f:
            preprocessing.write_tf_examples(f.name, tfexamples)
            recovered_data = self.extract_data(f.name, filter_amount=.05)

        # TODO: this will flake out very infrequently.  Use set_random_seed
        self.assertLess(len(recovered_data), 50) 
Developer: mlperf, Project: training_results_v0.5, Lines of code: 12, Source file: test_preprocessing.py

Example 5: test_serialize_round_trip_no_parse

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_serialize_round_trip_no_parse(self):
        np.random.seed(1)
        raw_data = self.create_random_data(10)
        tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

        with tempfile.NamedTemporaryFile() as start_file, \
                tempfile.NamedTemporaryFile() as rewritten_file:
            preprocessing.write_tf_examples(start_file.name, tfexamples)
            # We want to test that the rewritten, shuffled file contains correctly
            # serialized tf.Examples.
            batch_size = 4
            batches = list(preprocessing.shuffle_tf_examples(
                batch_size, [start_file.name]))
            # 2 batches of 4, 1 incomplete batch of 2.
            self.assertEqual(len(batches), 3)

            # concatenate list of lists into one list
            all_batches = list(itertools.chain.from_iterable(batches))

            for batch in batches:
                preprocessing.write_tf_examples(
                    rewritten_file.name, all_batches, serialize=False)

            original_data = self.extract_data(start_file.name)
            recovered_data = self.extract_data(rewritten_file.name)

        # stuff is shuffled, so sort before checking equality
        def sort_key(nparray_tuple): return nparray_tuple[2]
        original_data = sorted(original_data, key=sort_key)
        recovered_data = sorted(recovered_data, key=sort_key)

        self.assertEqualData(original_data, recovered_data) 
Developer: mlperf, Project: training_results_v0.5, Lines of code: 34, Source file: test_preprocessing.py
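Note that this copy calls shuffle_tf_examples(batch_size, [start_file.name]), while the Gun-Detector copy in Example 8 below passes an extra leading argument (1000), so the signature differs between project versions; check the preprocessing.py you actually have. Conceptually, both read the serialized records back, shuffle them, and yield fixed-size batches with a possibly short final batch (10 records with batch_size 4 gives batches of 4, 4, 2, matching the assertion above). The sketch below only illustrates that batching behavior under those assumptions and is not the project's implementation:

import random

import tensorflow as tf


def shuffle_serialized_examples_sketch(batch_size, tf_record_paths, seed=None):
    # Conceptual sketch: load every serialized record into memory, shuffle,
    # and yield batches.  Compression options are omitted here, although the
    # project's write_tf_examples may write compressed TFRecords.
    records = []
    for path in tf_record_paths:
        # TF1-style record iterator, matching the era of the examples above.
        records.extend(tf.python_io.tf_record_iterator(path))
    random.Random(seed).shuffle(records)
    for i in range(0, len(records), batch_size):
        yield records[i:i + batch_size]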

Example 6: test_serialize_round_trip

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_serialize_round_trip(self):
    np.random.seed(1)
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

    with tempfile.NamedTemporaryFile() as f:
      preprocessing.write_tf_examples(f.name, tfexamples)
      recovered_data = self.extract_data(f.name)

    self.assertEqualData(raw_data, recovered_data) 
Developer: itsamitgoel, Project: Gun-Detector, Lines of code: 12, Source file: preprocessing_test.py

Example 7: test_filter

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_filter(self):
    raw_data = self.create_random_data(100)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

    with tempfile.NamedTemporaryFile() as f:
      preprocessing.write_tf_examples(f.name, tfexamples)
      recovered_data = self.extract_data(f.name, filter_amount=.05)

    self.assertLess(len(recovered_data), 50) 
Developer: itsamitgoel, Project: Gun-Detector, Lines of code: 11, Source file: preprocessing_test.py

Example 8: test_serialize_round_trip_no_parse

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def test_serialize_round_trip_no_parse(self):
    np.random.seed(1)
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

    with tempfile.NamedTemporaryFile() as start_file, \
        tempfile.NamedTemporaryFile() as rewritten_file:
      preprocessing.write_tf_examples(start_file.name, tfexamples)
      # We want to test that the rewritten, shuffled file contains correctly
      # serialized tf.Examples.
      batch_size = 4
      batches = list(preprocessing.shuffle_tf_examples(
          1000, batch_size, [start_file.name]))
      # 2 batches of 4, 1 incomplete batch of 2.
      self.assertEqual(len(batches), 3)

      # concatenate list of lists into one list
      all_batches = list(itertools.chain.from_iterable(batches))

      for _ in batches:
        preprocessing.write_tf_examples(
            rewritten_file.name, all_batches, serialize=False)

      original_data = self.extract_data(start_file.name)
      recovered_data = self.extract_data(rewritten_file.name)

    # stuff is shuffled, so sort before checking equality
    def sort_key(nparray_tuple):
      return nparray_tuple[2]
    original_data = sorted(original_data, key=sort_key)
    recovered_data = sorted(recovered_data, key=sort_key)

    self.assertEqualData(original_data, recovered_data) 
Developer: itsamitgoel, Project: Gun-Detector, Lines of code: 35, Source file: preprocessing_test.py

Example 9: convert

# Required import: import preprocessing [as alias]
# Or: from preprocessing import make_tf_example [as alias]
def convert(paths):
    position, in_path, out_path = paths
    assert tf.gfile.Exists(in_path)
    assert tf.gfile.Exists(os.path.dirname(out_path))

    in_size = get_size(in_path)
    if tf.gfile.Exists(out_path):
        # Make sure out_path is about the size of in_path
        size = get_size(out_path)
        error = (size - in_size) / (in_size + 1)
        # 5% smaller to 20% larger
        if -0.05 < error < 0.20:
            return out_path + " already existed"
        return "ERROR on file size ({:.1f}% diff) {}".format(
            100 * error, out_path)

    num_batches = dual_net.EXAMPLES_PER_GENERATION // FLAGS.batch_size + 1

    with tf.python_io.TFRecordWriter(out_path, OPTS) as writer:
        record_iter = tqdm(
            batched_reader(in_path),
            desc=os.path.basename(in_path),
            position=position,
            total=num_batches)
        for record in record_iter:
            xs, rs = preprocessing.batch_parse_tf_example(len(record), record)
            # Undo cast in batch_parse_tf_example.
            xs = tf.cast(xs, tf.uint8)

            # map the rotation function.
            x_rot, r_rot = preprocessing._random_rotation(xs, rs)

            with tf.Session() as sess:
                x_rot, r_rot = sess.run([x_rot, r_rot])
            tf.reset_default_graph()

            pi_rot = r_rot['pi_tensor']
            val_rot = r_rot['value_tensor']
            for r, x, pi, val in zip(record, x_rot, pi_rot, val_rot):
                record_out = preprocessing.make_tf_example(x, pi, val)
                serialized = record_out.SerializeToString()
                writer.write(serialized)
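                # Sanity check: rotation only permutes array entries, so the
                # rotated example should serialize to the same number of bytes
                # as the original record.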
                assert len(r) == len(serialized), (len(r), len(serialized)) 
Developer: mlperf, Project: training, Lines of code: 45, Source file: rotate_examples.py


Note: the preprocessing.make_tf_example method examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets were selected from open-source projects contributed by their respective developers, and copyright remains with the original authors; refer to each project's license before distributing or reusing the code. Do not republish without permission.