本文整理汇总了Python中tensorflow.python.data.ops.dataset_ops.DatasetV2类的典型用法代码示例。如果您想了解Python中dataset_ops.DatasetV2的具体用法、使用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorflow.python.data.ops.dataset_ops的用法示例。
在下文中一共展示了dataset_ops.DatasetV2类的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: parse_input_fn_result
# 需要导入模块: from tensorflow.python.data.ops import dataset_ops [as 别名]
# 或者: from tensorflow.python.data.ops.dataset_ops import DatasetV2 [as 别名]
def parse_input_fn_result(result):
    """Split an Estimator ``input_fn`` result into features, labels and hooks.

    Args:
      result: The value returned by an Estimator ``input_fn``. Either a
        ``tf.data.Dataset`` whose elements are ``(features, labels)`` tuples,
        or a ``(features, labels)`` tuple directly. ``features`` is a
        ``Tensor`` or a dict mapping string feature names to ``Tensor``;
        ``labels`` is a ``Tensor`` or a dict mapping string label names to
        ``Tensor``. Both are consumed by ``model_fn`` and should satisfy its
        expectations of its inputs.

    Returns:
      A tuple ``(features, labels, input_hooks)``. ``labels`` may be None.
      ``input_hooks`` is a list of SessionRunHooks to include when running;
      it holds one hook when ``result`` is a Dataset and is empty otherwise.

    Raises:
      ValueError: if the result is a list or tuple of length != 2.
    """
    if not isinstance(result, dataset_ops.DatasetV2):
        # Plain (features, labels) value: nothing needs initializing.
        return parse_iterator_result(result) + ([],)

    # A Dataset is turned into an initializable iterator; the accompanying
    # _DatasetInitializerHook presumably runs that iterator's initializer
    # when the session starts (TODO confirm against the hook's definition).
    iterator = dataset_ops.make_initializable_iterator(result)
    hooks = [_DatasetInitializerHook(iterator)]
    return parse_iterator_result(iterator.get_next()) + (hooks,)
示例2: from_input_fn
# 需要导入模块: from tensorflow.python.data.ops import dataset_ops [as 别名]
# 或者: from tensorflow.python.data.ops.dataset_ops import DatasetV2 [as 别名]
def from_input_fn(return_values):
    """Build an ``_Inputs`` instance from whatever ``input_fn`` returned.

    A ``DatasetV2`` is passed through unchanged as the ``dataset`` keyword
    argument. Any other value is split into ``(features, labels)`` by
    ``_Inputs._parse_inputs`` and passed positionally.
    """
    if not isinstance(return_values, dataset_ops.DatasetV2):
        features, labels = _Inputs._parse_inputs(return_values)
        return _Inputs(features, labels)
    return _Inputs(dataset=return_values)
示例3: on_epoch_end
# 需要导入模块: from tensorflow.python.data.ops import dataset_ops [as 别名]
# 或者: from tensorflow.python.data.ops.dataset_ops import DatasetV2 [as 别名]
def on_epoch_end(self, epoch, logs=None):
    """Compute the token error on ``self.eval_ds`` after an epoch.

    Reads ``len(self.eval_task)`` evaluation batches. ``self.eval_ds`` may be:

    * a ``tf.data`` Dataset: a new one-shot iterator is created on every call;
    * a TF iterator: batches are fetched from it via ``get_next()``;
    * anything else: treated as an indexable sequence, read as
      ``self.eval_ds[index][0]``.

    Each batch must be a mapping with ``'inputs'`` and ``'targets'`` entries;
    ``'targets'`` must support ``.tolist()`` (presumably a NumPy array).

    Predictions come from ``self.func(inputs)[0]``. They are decoded with
    ``py_ctc.ctc_greedy_decode`` when ``self.decoder_type == 'argmax'``.
    Otherwise they go through CTC beam search (``beam_width=3``,
    ``top_paths=3``), keeping the first returned path. The decoded and target
    sequences are scored with ``metrics_lib.token_error(eos_id=0)``.

    Args:
      epoch: Index of the finished epoch; log messages report ``epoch + 1``.
      logs: Metrics dict supplied by Keras. It is updated in place with
        ``'val_token_err'``. Defaults to a fresh empty dict.
    """
    # A literal ``{}`` default would be one dict shared across every call,
    # and it is mutated below, so results would leak between calls.
    if logs is None:
        logs = {}

    cur_session = tf.keras.backend.get_session()
    target_seq_list, predict_seq_list = [], []

    is_py_sequence = True
    if isinstance(self.eval_ds, (dataset_ops.DatasetV2, dataset_ops.DatasetV1)):
        # ``DatasetV2`` does not expose a ``make_one_shot_iterator`` method;
        # the module-level helper accepts both V1 and V2 datasets.
        eval_gen = dataset_ops.make_one_shot_iterator(self.eval_ds)
        self.next_batch_gen = eval_gen.get_next()[0]
        is_py_sequence = False
    elif isinstance(self.eval_ds,
                    (iterator_ops.IteratorV2, iterator_ops.Iterator)):
        # Read from the iterator that was just type-checked. The original
        # code read ``self.ds``, a different (likely undefined) attribute.
        self.next_batch_gen = self.eval_ds.get_next()[0]
        is_py_sequence = False

    for index in range(len(self.eval_task)):
        if is_py_sequence:
            batch_data = self.eval_ds[index][0]
        else:
            batch_data = cur_session.run(self.next_batch_gen)
        batch_input = batch_data['inputs']
        batch_target = batch_data['targets'].tolist()
        batch_predict = self.func(batch_input)[0]
        if self.decoder_type == 'argmax':
            predict_seq_list += py_ctc.ctc_greedy_decode(
                batch_predict, 0, unique=True)
        else:
            sequence_lens = [len(pre_sequence) for pre_sequence in batch_predict]
            # NOTE(review): tf.constant / ctc_beam_search_decode add new ops
            # to the session graph for every batch of every epoch, so the
            # graph grows over training; consider placeholders built once.
            batch_decoder, _ = tf_ctc.ctc_beam_search_decode(
                tf.constant(batch_predict),
                tf.constant(sequence_lens),
                beam_width=3,
                top_paths=3)
            # Keep only the first (top) decoded path.
            predict_seq_list += cur_session.run(batch_decoder)[0].tolist()
        target_seq_list += batch_target

    val_token_errors = metrics_lib.token_error(
        predict_seq_list=predict_seq_list,
        target_seq_list=target_seq_list,
        eos_id=0)
    logs['val_token_err'] = val_token_errors

    if 'val_loss' in logs:
        logging.info("Epoch {}: on eval, val_loss is {}.".format(
            epoch + 1, logs['val_loss']))
    logging.info("Epoch {}: on eval, token_err is {}.".format(
        epoch + 1, val_token_errors))
    # Guard like 'val_loss': 'loss' is absent when logs was not supplied.
    if 'loss' in logs:
        logging.info("Epoch {}: loss on train is {}".format(epoch + 1,
                                                            logs['loss']))