

Python dataset_ops.DatasetV2 Usage Examples

This article collects typical usage examples of tensorflow.python.data.ops.dataset_ops.DatasetV2 in Python. If you are wondering how dataset_ops.DatasetV2 is used in practice, the selected code examples below may help. You can also browse other usage examples from the tensorflow.python.data.ops.dataset_ops module.


Three code examples of dataset_ops.DatasetV2 are shown below, ordered by popularity by default.
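Before the examples, a minimal sketch (not taken from the projects listed here) of what the isinstance checks below match against. In current TensorFlow releases the public tf.data.Dataset class is backed by dataset_ops.DatasetV2, and DatasetV1 subclasses it, so any dataset built through the tf.data API should pass the check:

import tensorflow as tf
from tensorflow.python.data.ops import dataset_ops

# Any dataset built with the public tf.data API is a DatasetV2 instance,
# which is exactly what the examples below test with isinstance().
ds = tf.data.Dataset.from_tensor_slices(([1.0, 2.0, 3.0], [0, 1, 0]))
print(isinstance(ds, dataset_ops.DatasetV2))  # expected: True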

Example 1: parse_input_fn_result

# Required import: from tensorflow.python.data.ops import dataset_ops [as alias]
# Or: from tensorflow.python.data.ops.dataset_ops import DatasetV2 [as alias]
def parse_input_fn_result(result):
  """Gets features, labels, and hooks from the result of an Estimator input_fn.

  Args:
    result: output of an input_fn to an estimator, which should be one of:
      * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a tuple
        (features, labels) with same constraints as below.
      * A tuple (features, labels): Where `features` is a `Tensor` or a
        dictionary of string feature name to `Tensor` and `labels` is a `Tensor`
        or a dictionary of string label name to `Tensor`. Both `features` and
        `labels` are consumed by `model_fn`. They should satisfy the expectation
        of `model_fn` from inputs.

  Returns:
    Tuple of features, labels, and input_hooks, where features are as described
    above, labels are as described above or None, and input_hooks are a list
    of SessionRunHooks to be included when running.

  Raises:
    ValueError: if the result is a list or tuple of length != 2.
  """
  input_hooks = []
  if isinstance(result, dataset_ops.DatasetV2):
    iterator = dataset_ops.make_initializable_iterator(result)
    input_hooks.append(_DatasetInitializerHook(iterator))
    result = iterator.get_next()
  return parse_iterator_result(result) + (input_hooks,) 
Developer ID: tensorflow, Project: estimator, Lines of code: 29, Source file: util.py
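
A hedged usage sketch for Example 1: the function consumes whatever an Estimator input_fn returned, so a tf.data.Dataset yielding (features, labels) pairs is unwrapped into tensors plus an initializer hook. The import path tensorflow_estimator.python.estimator.util is an assumption about where this util.py lives in an installed package, and the explicit graph scope is used because make_initializable_iterator is a graph-mode API:

import tensorflow as tf
from tensorflow_estimator.python.estimator import util  # assumed install path

def input_fn():
  # A Dataset whose elements are (features, labels), as the docstring requires.
  features = {"x": [[1.0], [2.0], [3.0], [4.0]]}
  labels = [0, 1, 0, 1]
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)

with tf.Graph().as_default():
  features, labels, input_hooks = util.parse_input_fn_result(input_fn())
  # features/labels are now tensors from iterator.get_next(); input_hooks
  # holds the _DatasetInitializerHook that initializes the iterator at run time.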

Example 2: from_input_fn

# Required import: from tensorflow.python.data.ops import dataset_ops [as alias]
# Or: from tensorflow.python.data.ops.dataset_ops import DatasetV2 [as alias]
def from_input_fn(return_values):
    """Returns an `_Inputs` instance according to `input_fn` return value."""
    if isinstance(return_values, dataset_ops.DatasetV2):
      dataset = return_values
      return _Inputs(dataset=dataset)

    features, labels = _Inputs._parse_inputs(return_values)
    return _Inputs(features, labels) 
Developer ID: ymcui, Project: Chinese-XLNet, Lines of code: 10, Source file: tpu_estimator.py
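
Example 2 distinguishes the two shapes an input_fn may return: a tf.data Dataset, or a (features, labels) tuple. A hedged sketch of that branching with a stand-in container class, since the real _Inputs class is private to tpu_estimator.py and not imported here:

import tensorflow as tf
from tensorflow.python.data.ops import dataset_ops

class InputsSketch:
  """Stand-in for the private _Inputs container; illustration only."""

  def __init__(self, features=None, labels=None, dataset=None):
    self.features, self.labels, self.dataset = features, labels, dataset

def from_input_fn_sketch(return_values):
  if isinstance(return_values, dataset_ops.DatasetV2):
    # Branch 1: input_fn returned a Dataset; keep it whole.
    return InputsSketch(dataset=return_values)
  # Branch 2: input_fn returned a (features, labels) tuple.
  features, labels = return_values
  return InputsSketch(features=features, labels=labels)

print(from_input_fn_sketch(tf.data.Dataset.range(4)).dataset is not None)  # True
print(from_input_fn_sketch(({"x": [1.0, 2.0]}, [0, 1])).features)          # {'x': [1.0, 2.0]}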

Example 3: on_epoch_end

# Required import: from tensorflow.python.data.ops import dataset_ops [as alias]
# Or: from tensorflow.python.data.ops.dataset_ops import DatasetV2 [as alias]
def on_epoch_end(self, epoch, logs={}):
    '''Compute token error on the eval set at the end of each epoch.'''

    cur_session = tf.keras.backend.get_session()
    target_seq_list, predict_seq_list = [], []

    is_py_sequence = True
    if isinstance(self.eval_ds, (dataset_ops.DatasetV2, dataset_ops.DatasetV1)):
      eval_gen = self.eval_ds.make_one_shot_iterator()
      self.next_batch_gen = eval_gen.get_next()[0]
      is_py_sequence = False
    elif isinstance(self.eval_ds,
                    (iterator_ops.IteratorV2, iterator_ops.Iterator)):
      self.next_batch_gen = self.eval_ds.get_next()[0]
      is_py_sequence = False

    for index in range(len(self.eval_task)):
      batch_data = None
      if is_py_sequence:
        batch_data = self.eval_ds[index][0]
      else:
        batch_data = cur_session.run(self.next_batch_gen)
      batch_input = batch_data['inputs']
      batch_target = batch_data['targets'].tolist()
      batch_predict = self.func(batch_input)[0]

      if self.decoder_type == 'argmax':
        predict_seq_list += py_ctc.ctc_greedy_decode(
            batch_predict, 0, unique=True)
      else:
        sequence_lens = [len(pre_sequence) for pre_sequence in batch_predict]
        batch_decoder, _ = tf_ctc.ctc_beam_search_decode(
            tf.constant(batch_predict),
            tf.constant(sequence_lens),
            beam_width=3,
            top_paths=3)
        predict_seq_list += cur_session.run(batch_decoder)[0].tolist()
      target_seq_list += batch_target

    val_token_errors = metrics_lib.token_error(
        predict_seq_list=predict_seq_list,
        target_seq_list=target_seq_list,
        eos_id=0)
    logs['val_token_err'] = val_token_errors

    if 'val_loss' in logs:
      logging.info("Epoch {}: on eval, val_loss is {}.".format(
          epoch + 1, logs['val_loss']))
    logging.info("Epoch {}: on eval, token_err is {}.".format(
        epoch + 1, val_token_errors))
    logging.info("Epoch {}: loss on train is {}".format(epoch + 1,
                                                        logs['loss'])) 
Developer ID: didi, Project: delta, Lines of code: 54, Source file: callbacks.py
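
Example 3 is the on_epoch_end hook of a Keras callback in the delta project. A hedged sketch of how such a callback is attached to training; the class below is a hypothetical, trimmed-down stand-in that only records a placeholder metric, since the real callback needs delta-specific objects (eval dataset, decode function, decoder type):

import tensorflow as tf

class TokenErrCallbackSketch(tf.keras.callbacks.Callback):
  """Hypothetical stand-in: records a dummy val_token_err so the wiring is visible."""

  def on_epoch_end(self, epoch, logs=None):
    logs = logs if logs is not None else {}
    # The real callback decodes predictions on the eval set (greedy or CTC
    # beam search) and computes a token error rate here.
    logs['val_token_err'] = 0.0

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
x = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))
model.fit(x, y, epochs=1, callbacks=[TokenErrCallbackSketch()], verbose=0)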


Note: The tensorflow.python.data.ops.dataset_ops.DatasetV2 examples in this article were compiled by 純淨天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The snippets are selected from open-source projects contributed by their respective authors, and copyright of the source code remains with the original authors. Please follow each project's license when distributing or using the code, and do not republish this article without permission.