当前位置: 首页>>代码示例>>Python>>正文


Python tpu.TPUEstimator类代码示例

本文整理汇总了Python中tensorflow.contrib.tpu.TPUEstimator类的典型用法代码示例。如果您想了解tpu.TPUEstimator的具体用法、调用方式或实际使用示例,这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该类所在模块tensorflow.contrib.tpu的其他用法示例。


下文共展示了tpu.TPUEstimator的15个代码示例,默认按受欢迎程度排序。其中部分示例(如各个input_fn构造函数)并不直接调用TPUEstimator,而是为其构造输入函数或外围配置。您可以为喜欢或觉得有用的代码点赞,您的评价将有助于系统推荐更优质的Python代码示例。

示例1: main

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def main(argv):
  """Trains an MNIST classifier with K-FAC on TPU using TPUEstimator.

  All configuration is read from command-line FLAGS.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.

  # update_damping_immediately needs resource variables (they are probably
  # already on by default on TPUs, but enable them explicitly).
  if FLAGS.update_damping_immediately:
    tf.enable_resource_variables()

  tf.set_random_seed(FLAGS.seed)

  # Cholesky decomposition + triangular solve is currently the only matrix
  # inversion code path supported on TPU.
  kfac.utils.set_global_constants(posdef_inv_method='cholesky')
  kfac.fisher_factors.set_global_constants(
      eigenvalue_decomposition_threshold=10000)

  # Extra patch-related Fisher-factor settings; skipped entirely when the SUA
  # approximation is in use.
  if FLAGS.use_sua_approx:
    patch_settings = None
  elif FLAGS.use_custom_patches_op:
    patch_settings = dict(use_patches_second_moment_op=True)
  else:
    # Temporary measure to save memory with giant batches.
    patch_settings = dict(sub_sample_inputs=True,
                          inputs_to_extract_patches_factor=0.1)
  if patch_settings is not None:
    kfac.fisher_factors.set_global_constants(**patch_settings)

  run_config = make_tpu_run_config(
      FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop,
      FLAGS.save_checkpoints_steps)

  classifier = contrib_tpu.TPUEstimator(
      use_tpu=True,
      model_fn=_model_fn,
      config=run_config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=1024)
  classifier.train(input_fn=mnist_input_fn, max_steps=FLAGS.train_steps,
                   hooks=[])
开发者ID:tensorflow,项目名称:kfac,代码行数:43,代码来源:classifier_mnist_tpu_estimator.py

示例2: main

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def main(argv):
  """Trains an MNIST autoencoder with K-FAC on TPU using TPUEstimator.

  All configuration is read from command-line FLAGS.

  Args:
    argv: Command-line arguments; unused.
  """
  if FLAGS.use_control_flow_v2:
    tf.enable_control_flow_v2()

  del argv  # Unused.
  tf.set_random_seed(FLAGS.seed)

  # Matrix inversion on TPU is only supported via Cholesky decomposition
  # followed by a triangular solve.
  kfac.utils.set_global_constants(posdef_inv_method='cholesky')
  kfac.fisher_factors.set_global_constants(
      eigenvalue_decomposition_threshold=10000)

  tpu_run_config = make_tpu_run_config(
      FLAGS.master, FLAGS.seed, FLAGS.model_dir, FLAGS.iterations_per_loop,
      FLAGS.save_checkpoints_steps)

  autoencoder = contrib_tpu.TPUEstimator(
      use_tpu=True,
      model_fn=_model_fn,
      config=tpu_run_config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=1024)
  autoencoder.train(input_fn=mnist_input_fn, max_steps=FLAGS.train_steps,
                    hooks=[])
开发者ID:tensorflow,项目名称:kfac,代码行数:30,代码来源:autoencoder_mnist_tpu_estimator.py

示例3: create_estimator

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def create_estimator(t2r_model,
                     model_dir,
                     params=None,
                     **kwargs):
  """Builds a plain `tf.estimator.Estimator` for a T2R model.

  This mirrors the TPU estimator factory's interface so callers can switch
  between the two without changing their call sites.

  Args:
    t2r_model: The model instance to train or evaluate; supplies `model_fn`
      and `get_run_config()`.
    model_dir: Optional directory to save the model to / restore it from.
    params: Optional dict of hyperparameters forwarded to `input_fn` and
      `model_fn`. Keys are parameter names, values are basic python types.
      Some keys, such as 'batch_size', are reserved by TPUEstimator.
    **kwargs: Ignored; accepted only for interface parity with the TPU
      estimator factory.

  Returns:
    A `tf.estimator.Estimator`.
  """
  del kwargs  # Interface parity only.
  model_fn = t2r_model.model_fn
  run_config = t2r_model.get_run_config()
  return tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=model_dir,
      config=run_config,
      params=params)
开发者ID:google-research,项目名称:tensor2robot,代码行数:28,代码来源:train_eval.py

示例4: _train_and_eval_reference_model

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def _train_and_eval_reference_model(self, path, multi_dataset=False):
    """Trains a mock T2R model, then predicts with a plain Estimator.

    Args:
      path: Not used by this method body.
      multi_dataset: Forwarded to the mock model and mock input generator.

    Returns:
      A `(model_dir, mock_t2r_model, prediction_ref)` tuple. `prediction_ref`
      is the result of `Estimator.predict` on EVAL-mode input, restored from
      the freshly trained checkpoint in `model_dir`.
    """
    model_dir = self.create_tempdir().full_path
    t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        multi_dataset=multi_dataset)

    # Train through a TPUEstimator; it only targets a TPU if the mock model
    # reports a TPU device.
    train_estimator = contrib_tpu.TPUEstimator(
        model_fn=t2r_model.model_fn,
        use_tpu=t2r_model.is_device_tpu,
        config=contrib_tpu.RunConfig(model_dir=model_dir),
        train_batch_size=BATCH_SIZE,
        eval_batch_size=BATCH_SIZE)

    input_generator = mocks.MockInputGenerator(
        batch_size=BATCH_SIZE, multi_dataset=multi_dataset)
    input_generator.set_specification_from_model(
        t2r_model, tf.estimator.ModeKeys.TRAIN)

    train_input_fn = input_generator.create_dataset_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN)
    train_estimator.train(input_fn=train_input_fn, max_steps=MAX_STEPS)

    # A plain Estimator over the same model_dir produces the reference
    # predictions (presumably compared by callers against a serving/exported
    # model using all the same parameters).
    predict_estimator = tf.estimator.Estimator(
        model_fn=t2r_model.model_fn,
        config=tf.estimator.RunConfig(model_dir=model_dir))
    eval_input_fn = input_generator.create_dataset_input_fn(
        mode=tf.estimator.ModeKeys.EVAL)
    prediction_ref = predict_estimator.predict(input_fn=eval_input_fn)

    return model_dir, t2r_model, prediction_ref
开发者ID:google-research,项目名称:tensor2robot,代码行数:38,代码来源:default_export_generator_test.py

示例5: __init__

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def __init__(self, tokenizer, init_checkpoint):
    """Sets up a BERT Hotpot model for one-example-at-a-time prediction.

    Args:
      tokenizer: Tokenizer stored on the instance.
      init_checkpoint: Checkpoint passed to the model-function builder.
    """
    self.max_seq_length = FLAGS.max_hotpot_seq_length
    self.max_qry_length = FLAGS.max_hotpot_query_length
    self.batch_size = 1
    self.tokenizer = tokenizer

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    with tf.device("/cpu:0"):
      # The model is only used for prediction here, so every training-related
      # argument is zeroed out.
      model_fn = hotpot_model_fn_builder(
          bert_config=bert_config,
          init_checkpoint=init_checkpoint,
          learning_rate=0.0,
          num_train_steps=0,
          num_warmup_steps=0,
          use_tpu=False,
          use_one_hot_embeddings=False)

    # TPUEstimator with use_tpu=False runs without a TPU; batch size is 1 for
    # both training and prediction.
    cpu_estimator = contrib_tpu.TPUEstimator(
        use_tpu=False,
        model_fn=model_fn,
        config=contrib_tpu.RunConfig(),
        train_batch_size=self.batch_size,
        predict_batch_size=self.batch_size)
    self.fast_predictor = FastPredict(cpu_estimator, self.get_input_fn)

    # Record type for a candidate answer span with its start/end logits.
    self._PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["start_index", "end_index", "start_logit", "end_logit"])
开发者ID:google-research,项目名称:language,代码行数:30,代码来源:demo.py

示例6: create_tpu_estimator

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def create_tpu_estimator(t2r_model,
                         model_dir,
                         train_batch_size=32,
                         eval_batch_size=1,
                         use_tpu_hardware=True,
                         params=None,
                         export_to_cpu=True,
                         export_to_tpu=True,
                         **kwargs):
  """Builds a `contrib_tpu.TPUEstimator` for a T2R model.

  Args:
    t2r_model: The model instance to train or evaluate; supplies `model_fn`,
      `get_tpu_run_config()` and `is_device_tpu`.
    model_dir: Optional directory to save the model to / restore it from.
    train_batch_size: Batch size used for training.
    eval_batch_size: Batch size used for evaluation.
    use_tpu_hardware: When False, the TPUEstimator still runs but on CPU or
      GPU (whichever is available). Useful for debugging; otherwise leave it
      at the default.
    params: Optional dict of hyperparameters forwarded to `input_fn` and
      `model_fn`. Keys are parameter names, values are basic python types.
      Some keys, such as 'batch_size', are reserved by TPUEstimator.
    export_to_cpu: Whether to export a CPU SavedModel.
    export_to_tpu: Whether to export a TPU SavedModel.
    **kwargs: Ignored; accepted only so this factory shares an interface with
      the plain-estimator factory.

  Returns:
    A `contrib_tpu.TPUEstimator`.
  """
  del kwargs  # Interface parity only.
  model_fn = t2r_model.model_fn
  run_config = t2r_model.get_tpu_run_config()
  # Only target the TPU when the model wants one and the caller allows it.
  use_tpu = t2r_model.is_device_tpu and use_tpu_hardware
  return contrib_tpu.TPUEstimator(
      model_fn=model_fn,
      model_dir=model_dir,
      config=run_config,
      use_tpu=use_tpu,
      train_batch_size=train_batch_size,
      eval_batch_size=eval_batch_size,
      export_to_cpu=export_to_cpu,
      export_to_tpu=export_to_tpu,
      params=params)
开发者ID:google-research,项目名称:tensor2robot,代码行数:45,代码来源:train_eval.py

示例7: file_based_input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_file: TFRecord file(s) of serialized `tf.Example`s.
    seq_length: Fixed length of the `input_ids`, `input_mask` and
      `segment_ids` features.
    is_training: If True, records are repeated indefinitely and shuffled.
    drop_remainder: Whether to drop a final batch smaller than `batch_size`.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of feature dicts in which every int64
    feature has been cast to int32.
  """
  name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.FixedLenFeature([], tf.int64),
      "is_real_example": tf.FixedLenFeature([], tf.int64),
  }

  def _decode_record(record, name_to_features):
    """Parses one serialized example and casts its int64 features to int32."""
    example = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # `tf.cast` replaces the deprecated `tf.to_int32`.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def input_fn(params):
    """Builds the dataset; `params["batch_size"]` is reserved by TPUEstimator."""
    batch_size = params["batch_size"]

    # Training data is repeated indefinitely and lightly shuffled (buffer of
    # 100 records); eval data is read once, in file order.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)

    d = d.apply(
        contrib_data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
开发者ID:google-research,项目名称:language,代码行数:48,代码来源:run_classifier_membership.py

示例8: input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def input_fn_builder(features, seq_length, is_training, drop_remainder):
  """Creates an in-memory `input_fn` closure to be passed to TPUEstimator.

  Args:
    features: Sequence of feature objects exposing `input_ids`, `input_mask`,
      `segment_ids` and `label_id` attributes.
    seq_length: Sequence length; per-token features are shaped
      `[len(features), seq_length]`.
    is_training: If True, the dataset is repeated indefinitely and shuffled.
    drop_remainder: Forwarded to `Dataset.batch`.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of int32 feature dicts.
  """
  input_ids = [feat.input_ids for feat in features]
  input_masks = [feat.input_mask for feat in features]
  segment_ids = [feat.segment_ids for feat in features]
  label_ids = [feat.label_id for feat in features]

  def input_fn(params):
    """Builds the batched dataset from the captured Python lists."""
    batch_size = params["batch_size"]
    n = len(features)
    token_shape = [n, seq_length]

    def _int32_constant(values, shape):
      return tf.constant(values, shape=shape, dtype=tf.int32)

    # Demo-only: every example is embedded as a graph constant, so this does
    # NOT scale to large data sets. Dataset.from_generator() is avoided
    # because it relies on tf.py_func, which is not TPU compatible; the
    # scalable route is reading TFRecords.
    dataset = tf.data.Dataset.from_tensor_slices({
        "input_ids": _int32_constant(input_ids, token_shape),
        "input_mask": _int32_constant(input_masks, token_shape),
        "segment_ids": _int32_constant(segment_ids, token_shape),
        "label_ids": _int32_constant(label_ids, [n]),
    })

    if is_training:
      dataset = dataset.repeat().shuffle(buffer_size=100)

    return dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)

  return input_fn


# This function is not used by this file but is still used by the Colab and
# people who depend on it. 
开发者ID:google-research,项目名称:language,代码行数:56,代码来源:run_classifier_membership.py

示例9: file_based_input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder, num_labels):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_file: TFRecord file(s) of serialized `tf.Example`s.
    seq_length: Fixed length of the `input_ids`, `input_mask` and
      `segment_ids` features.
    is_training: If True, records are repeated indefinitely and shuffled.
    drop_remainder: Whether to drop a final batch smaller than `batch_size`.
    num_labels: Length of the float32 `label_ids` feature.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of feature dicts. int64 features are
    cast to int32; the float32 `label_ids` are left as is.
  """
  name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.FixedLenFeature([num_labels], tf.float32),
      "is_real_example": tf.FixedLenFeature([], tf.int64),
  }

  def _decode_record(record, name_to_features):
    """Parses one serialized example and casts its int64 features to int32."""
    example = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # `tf.cast` replaces the deprecated `tf.to_int32`.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def input_fn(params):
    """Builds the dataset; `params["batch_size"]` is reserved by TPUEstimator."""
    batch_size = params["batch_size"]

    # Training data is repeated indefinitely and lightly shuffled (buffer of
    # 100 records); eval data is read once, in file order.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)

    d = d.apply(
        contrib_data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
开发者ID:google-research,项目名称:language,代码行数:48,代码来源:run_classifier_distillation.py

示例10: input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def input_fn_builder(features, seq_length, is_training, drop_remainder):
  """Creates an in-memory `input_fn` closure to be passed to TPUEstimator.

  Args:
    features: Sequence of feature objects exposing `input_ids`, `input_mask`,
      `segment_ids` and `label_id` attributes.
    seq_length: Sequence length; per-token features are shaped
      `[len(features), seq_length]`.
    is_training: If True, the dataset is repeated indefinitely and shuffled.
    drop_remainder: Forwarded to `Dataset.batch`.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of int32 feature dicts.
  """
  input_ids = [feat.input_ids for feat in features]
  input_masks = [feat.input_mask for feat in features]
  segment_ids = [feat.segment_ids for feat in features]
  label_ids = [feat.label_id for feat in features]

  def input_fn(params):
    """Builds the batched dataset from the captured Python lists."""
    batch_size = params["batch_size"]
    n = len(features)
    token_shape = [n, seq_length]

    def _int32_constant(values, shape):
      return tf.constant(values, shape=shape, dtype=tf.int32)

    # Demo-only: every example is embedded as a graph constant, so this does
    # NOT scale to large data sets. Dataset.from_generator() is avoided
    # because it relies on tf.py_func, which is not TPU compatible; the
    # scalable route is reading TFRecords.
    dataset = tf.data.Dataset.from_tensor_slices({
        "input_ids": _int32_constant(input_ids, token_shape),
        "input_mask": _int32_constant(input_masks, token_shape),
        "segment_ids": _int32_constant(segment_ids, token_shape),
        "label_ids": _int32_constant(label_ids, [n]),
    })

    if is_training:
      dataset = dataset.repeat().shuffle(buffer_size=100)

    return dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)

  return input_fn
开发者ID:google-research,项目名称:language,代码行数:52,代码来源:run_classifier_distillation.py

示例11: file_based_input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_file: TFRecord file(s) of serialized `tf.Example`s.
    seq_length: Fixed length of the `input_ids`, `input_mask` and
      `segment_ids` features.
    is_training: If True, records are repeated indefinitely and shuffled.
    drop_remainder: Whether to drop a final batch smaller than `batch_size`.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of feature dicts. int64 features are
    cast to int32; the float32 `probs` feature (length 2) is left as is.
  """
  name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.FixedLenFeature([], tf.int64),
      "probs": tf.FixedLenFeature([2], tf.float32)
  }

  def _decode_record(record, name_to_features):
    """Parses one serialized example and casts its int64 features to int32."""
    example = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # `tf.cast` replaces the deprecated `tf.to_int32`.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def input_fn(params):
    """Builds the dataset; `params["batch_size"]` is reserved by TPUEstimator."""
    batch_size = params["batch_size"]

    # Training data is repeated indefinitely and lightly shuffled (buffer of
    # 100 records); eval data is read once, in file order.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)

    d = d.apply(
        contrib_data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
开发者ID:google-research,项目名称:language,代码行数:48,代码来源:run_bert_boolq_diff.py

示例12: input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def input_fn_builder(features, seq_length, is_training, drop_remainder):
  """Creates an in-memory `input_fn` closure to be passed to TPUEstimator.

  Args:
    features: Sequence of feature objects exposing `input_ids`, `input_mask`,
      `segment_ids`, `label_id` and `probs` attributes (`probs` holds two
      floats per example).
    seq_length: Sequence length; per-token features are shaped
      `[len(features), seq_length]`.
    is_training: If True, the dataset is repeated indefinitely and shuffled.
    drop_remainder: Forwarded to `Dataset.batch`.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` with int32 token/label features and
    a float32 `probs` feature of shape `[2]` per example.
  """
  input_ids = [feat.input_ids for feat in features]
  input_masks = [feat.input_mask for feat in features]
  segment_ids = [feat.segment_ids for feat in features]
  label_ids = [feat.label_id for feat in features]
  probs = [feat.probs for feat in features]

  def input_fn(params):
    """Builds the batched dataset from the captured Python lists."""
    batch_size = params["batch_size"]
    n = len(features)
    token_shape = [n, seq_length]

    def _int32_constant(values, shape):
      return tf.constant(values, shape=shape, dtype=tf.int32)

    # Demo-only: every example is embedded as a graph constant, so this does
    # NOT scale to large data sets. Dataset.from_generator() is avoided
    # because it relies on tf.py_func, which is not TPU compatible; the
    # scalable route is reading TFRecords.
    dataset = tf.data.Dataset.from_tensor_slices({
        "input_ids": _int32_constant(input_ids, token_shape),
        "input_mask": _int32_constant(input_masks, token_shape),
        "segment_ids": _int32_constant(segment_ids, token_shape),
        "label_ids": _int32_constant(label_ids, [n]),
        "probs": tf.constant(probs, shape=[n, 2], dtype=tf.float32)
    })

    if is_training:
      dataset = dataset.repeat().shuffle(buffer_size=100)

    return dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)

  return input_fn


# This function is not used by this file but is still used by the Colab and
# people who depend on it. 
开发者ID:google-research,项目名称:language,代码行数:60,代码来源:run_bert_boolq_diff.py

示例13: input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_file: TFRecord file(s) of serialized `tf.Example`s.
    seq_length: Fixed length of the `input_ids`, `input_mask` and
      `segment_ids` features.
    is_training: If True, `start_positions` and `end_positions` are also
      parsed, and records are repeated indefinitely and shuffled.
    drop_remainder: Whether to drop a final batch smaller than `batch_size`.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of int32 feature dicts.
  """
  name_to_features = {
      "unique_ids": tf.FixedLenFeature([], tf.int64),
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
  }

  if is_training:
    name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)

  def _decode_record(record, name_to_features):
    """Parses one serialized example and casts its int64 features to int32."""
    example = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # `tf.cast` replaces the deprecated `tf.to_int32`.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def input_fn(params):
    """Builds the dataset; `params["batch_size"]` is reserved by TPUEstimator."""
    batch_size = params["batch_size"]

    # Training data is repeated indefinitely and lightly shuffled (buffer of
    # 100 records); eval data is read once, in file order.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)

    d = d.apply(
        contrib_data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
开发者ID:google-research,项目名称:language,代码行数:50,代码来源:run_squad.py

示例14: input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_file: TFRecord file(s) of serialized `tf.Example`s.
    seq_length: Fixed length of the `input_ids`, `input_mask` and
      `segment_ids` features.
    is_training: If True, `label_ids` is also parsed, and records are
      repeated indefinitely and shuffled.
    drop_remainder: Whether to drop a final batch smaller than `batch_size`.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of int32 feature dicts.
  """
  name_to_features = {
      "unique_ids": tf.FixedLenFeature([], tf.int64),
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
  }

  if is_training:
    name_to_features["label_ids"] = tf.FixedLenFeature([], tf.int64)

  def _decode_record(record, name_to_features):
    """Parses one serialized example and casts its int64 features to int32."""
    example = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # `tf.cast` replaces the deprecated `tf.to_int32`.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def input_fn(params):
    """Builds the dataset; `params["batch_size"]` is reserved by TPUEstimator."""
    batch_size = params["batch_size"]

    # Training data is repeated indefinitely and lightly shuffled (buffer of
    # 100 records); eval data is read once, in file order.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)

    d = d.apply(
        contrib_data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
开发者ID:google-research,项目名称:language,代码行数:49,代码来源:run_squad_membership.py

示例15: file_based_input_fn_builder

# 需要导入模块: from tensorflow.contrib import tpu [as 别名]
# 或者: from tensorflow.contrib.tpu import TPUEstimator [as 别名]
def file_based_input_fn_builder(input_file,
                                seq_length,
                                is_training,
                                drop_remainder,
                                skip=0):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_file: TFRecord file path or glob pattern (e.g. for sharded files).
    seq_length: Fixed length of the `input_ids`, `input_mask` and
      `segment_ids` features.
    is_training: If True, records are repeated indefinitely, the first
      `skip` records of the repeated stream are dropped, and the rest are
      shuffled.
    drop_remainder: Whether to drop a final batch smaller than `batch_size`.
    skip: Number of records to skip at the start of the training stream so
      data already consumed by an earlier call is not revisited. Ignored when
      `is_training` is False.

  Returns:
    A function `input_fn(params)` that reads `params["batch_size"]` and
    returns a batched `tf.data.Dataset` of int32 feature dicts that also
    carry `next_sentence_labels` under the key `label_ids`.
  """
  # Expand sharded/glob patterns. Only replace `input_file` when the glob
  # matched something: an empty match would otherwise turn into a silently
  # empty dataset instead of a "file not found" error when data is read.
  try:
    matched_files = tf.io.gfile.glob(input_file)
  except tf.errors.OpError:
    matched_files = None  # Not a usable pattern; keep `input_file` as is.
  if matched_files:
    input_file = matched_files

  name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "next_sentence_labels": tf.FixedLenFeature([], tf.int64),
  }

  def _decode_record(record, name_to_features):
    """Parses one example, aliases its label, and casts int64 to int32."""
    example = tf.parse_single_example(record, name_to_features)
    example["label_ids"] = example["next_sentence_labels"]
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # `tf.cast` replaces the deprecated `tf.to_int32`.
    return {
        name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
        for name, t in example.items()
    }

  def input_fn(params):
    """Builds the dataset; `params["batch_size"]` is reserved by TPUEstimator."""
    batch_size = params["batch_size"]

    # Training data is repeated indefinitely and lightly shuffled (buffer of
    # 100 records); eval data is read once, in file order.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      # Each call to input_fn() restarts the stream at record 0. Skipping
      # before shuffling drops exactly the first `skip` records of the
      # repeated stream, so already-seen data is not revisited. Skipping
      # after the shuffle would drop a random subset instead.
      d = d.skip(skip)
      d = d.shuffle(buffer_size=100)

    d = d.apply(
        tf.data.experimental.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
开发者ID:google-research,项目名称:language,代码行数:59,代码来源:run_binary_coherence.py


注:本文中的tensorflow.contrib.tpu.TPUEstimator方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。