本文整理汇总了Python中tensorflow.contrib.slim.python.slim.data.tfexample_decoder.ItemHandlerCallback方法的典型用法代码示例。如果您正苦于以下问题:Python tfexample_decoder.ItemHandlerCallback方法的具体用法?Python tfexample_decoder.ItemHandlerCallback怎么用?Python tfexample_decoder.ItemHandlerCallback使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类tensorflow.contrib.slim.python.slim.data.tfexample_decoder
的用法示例。
在下文中一共展示了tfexample_decoder.ItemHandlerCallback方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testDecodeExampleWithItemHandlerCallback
# 需要导入模块: from tensorflow.contrib.slim.python.slim.data import tfexample_decoder [as 别名]
# 或者: from tensorflow.contrib.slim.python.slim.data.tfexample_decoder import ItemHandlerCallback [as 别名]
def testDecodeExampleWithItemHandlerCallback(self):
  """Decodes a float depth map through a single-key ItemHandlerCallback.

  The callback adds 1 to the decoded tensor, so the test checks that
  `decoded - 1` matches the random array that was serialized.
  """
  np.random.seed(0)
  depth_shape = (2, 3, 1)
  expected_depth = np.random.rand(*depth_shape)
  example = example_pb2.Example(features=feature_pb2.Features(feature={
      'image/depth_map': self._EncodedFloatFeature(expected_depth),
  }))
  serialized = example.SerializeToString()

  with self.test_session():
    # The decoder expects a scalar string tensor.
    serialized = array_ops.reshape(serialized, shape=[])
    keys_to_features = {
        'image/depth_map':
            parsing_ops.FixedLenFeature(
                depth_shape,
                dtypes.float32,
                default_value=array_ops.zeros(depth_shape))
    }

    def HandleDepth(keys_to_tensors):
      # Single key registered, so the lone dict value is the depth map.
      depth = list(keys_to_tensors.values())[0]
      depth += 1
      return depth

    items_to_handlers = {
        'depth':
            tfexample_decoder.ItemHandlerCallback('image/depth_map',
                                                  HandleDepth)
    }
    decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                 items_to_handlers)
    [depth_tensor] = decoder.decode(serialized, ['depth'])
    decoded_depth = depth_tensor.eval()
    # Undo the +1 applied by the callback before comparing.
    self.assertAllClose(expected_depth, decoded_depth - 1)
示例2: testDecodeImageWithItemHandlerCallback
# 需要导入模块: from tensorflow.contrib.slim.python.slim.data import tfexample_decoder [as 别名]
# 或者: from tensorflow.contrib.slim.python.slim.data.tfexample_decoder import ItemHandlerCallback [as 别名]
def testDecodeImageWithItemHandlerCallback(self):
  """Decodes jpeg/png images via a multi-key ItemHandlerCallback.

  The callback inspects the 'image/format' feature and dispatches to the
  matching image decoder with control_flow_ops.case.
  """
  image_shape = (2, 3, 3)
  for image_encoding in ['jpeg', 'png']:
    image, serialized_example = self.GenerateImage(
        image_format=image_encoding, image_shape=image_shape)

    with self.test_session():

      def ConditionalDecoding(keys_to_tensors):
        """See base class."""
        image_buffer = keys_to_tensors['image/encoded']
        image_format = keys_to_tensors['image/format']

        def DecodePng():
          return image_ops.decode_png(image_buffer, 3)

        def DecodeJpg():
          return image_ops.decode_jpeg(image_buffer, 3)

        # Dispatch on the format string; anything that is not 'png'
        # falls through to the jpeg decoder.
        decoded = control_flow_ops.case(
            {math_ops.equal(image_format, 'png'): DecodePng},
            default=DecodeJpg,
            exclusive=True)
        return array_ops.reshape(decoded, image_shape)

      keys_to_features = {
          'image/encoded':
              parsing_ops.FixedLenFeature(
                  (), dtypes.string, default_value=''),
          'image/format':
              parsing_ops.FixedLenFeature(
                  (), dtypes.string, default_value='jpeg')
      }
      items_to_handlers = {
          'image':
              tfexample_decoder.ItemHandlerCallback(
                  ['image/encoded', 'image/format'], ConditionalDecoding)
      }
      decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                   items_to_handlers)
      [tf_image] = decoder.decode(serialized_example, ['image'])
      decoded_image = tf_image.eval()

      if image_encoding == 'png':
        # png is lossless, so the round trip must be exact.
        self.assertAllClose(image, decoded_image, atol=0)
      else:
        # For jenkins:
        image = image.astype(np.float32)
        decoded_image = decoded_image.astype(np.float32)
        self.assertAllClose(image, decoded_image, rtol=.5, atol=1.001)
示例3: make_data_provider
# 需要导入模块: from tensorflow.contrib.slim.python.slim.data import tfexample_decoder [as 别名]
# 或者: from tensorflow.contrib.slim.python.slim.data.tfexample_decoder import ItemHandlerCallback [as 别名]
def make_data_provider(self, **kwargs):
  """Creates a slim DatasetDataProvider for parallel text data.

  Source and target strings are stored as scalar string features; this
  builds SplitTokensDecoders for each side and wires them up as
  ItemHandlerCallbacks on a TFExampleDecoder so that tokens and lengths
  come out as separate items.

  Args:
    **kwargs: Forwarded to the DatasetDataProvider constructor.

  Returns:
    A tf.contrib.slim.dataset_data_provider.DatasetDataProvider yielding
    the items "source_tokens", "source_len", "target_tokens" and
    "target_len".
  """
  splitter_source = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=self.params["source_delimiter"])
  splitter_target = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="target_tokens",
      length_feature_name="target_len",
      prepend_token="SEQUENCE_START",
      append_token="SEQUENCE_END",
      delimiter=self.params["target_delimiter"])

  source_field = self.params["source_field"]
  target_field = self.params["target_field"]
  keys_to_features = {
      source_field: tf.FixedLenFeature((), tf.string),
      # The target may be absent (e.g. at inference time).
      target_field: tf.FixedLenFeature((), tf.string, default_value="")
  }

  def _token_handler(splitter, field, item):
    """Builds an ItemHandlerCallback decoding `item` from feature `field`.

    The factory binds its arguments per call, which avoids both the
    late-binding closure pitfall and the original lambdas' parameter
    named `dict` that shadowed the builtin.
    """
    return tfexample_decoder.ItemHandlerCallback(
        keys=[field],
        func=lambda parsed: splitter.decode(parsed[field], [item])[0])

  items_to_handlers = {
      "source_tokens": _token_handler(
          splitter_source, source_field, "source_tokens"),
      "source_len": _token_handler(
          splitter_source, source_field, "source_len"),
      "target_tokens": _token_handler(
          splitter_target, target_field, "target_tokens"),
      "target_len": _token_handler(
          splitter_target, target_field, "target_len"),
  }

  decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                               items_to_handlers)

  dataset = tf.contrib.slim.dataset.Dataset(
      data_sources=self.params["files"],
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=None,
      items_to_descriptions={})

  return tf.contrib.slim.dataset_data_provider.DatasetDataProvider(
      dataset=dataset,
      shuffle=self.params["shuffle"],
      num_epochs=self.params["num_epochs"],
      **kwargs)
开发者ID:akanimax,项目名称:natural-language-summary-generation-from-structured-data,代码行数:56,代码来源:input_pipeline.py