本文整理汇总了Python中tensorflow.contrib.tpu.TPUEstimatorSpec方法的典型用法代码示例。如果您正苦于以下问题：Python tpu.TPUEstimatorSpec方法的具体用法？Python tpu.TPUEstimatorSpec怎么用？Python tpu.TPUEstimatorSpec使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.contrib.tpu的用法示例。
在下文中一共展示了tpu.TPUEstimatorSpec方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: model_fn_builder
# 需要导入模块: from tensorflow.contrib import tpu as contrib_tpu
# 或者: from tensorflow.contrib.tpu import TPUEstimatorSpec [as 别名]
def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator.

    Args:
        bert_config: Passed as `config` to `modeling.BertModel`.
        init_checkpoint: Checkpoint handed to
            `modeling.get_assignment_map_from_checkpoint` and
            `tf.train.init_from_checkpoint`.
        layer_indexes: Iterable of indices into the list returned by
            `model.get_all_encoder_layers()`. The i-th index is exported
            under the prediction key `layer_output_<i>`.
        use_tpu: When truthy, checkpoint initialization is wrapped in a
            `scaffold_fn` given to the `TPUEstimatorSpec`; otherwise it is
            performed inline while `model_fn` runs.
        use_one_hot_embeddings: Forwarded unchanged to `modeling.BertModel`.

    Returns:
        A `model_fn(features, labels, mode, params)` callable that supports
        only `tf.estimator.ModeKeys.PREDICT`.
    """

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Reads `unique_ids`, `input_ids`, `input_mask` and `input_type_ids`
        from `features` and returns a `TPUEstimatorSpec` whose predictions
        hold `unique_id` plus one `layer_output_<i>` entry per requested
        layer.

        Raises:
            ValueError: If `mode` is not `tf.estimator.ModeKeys.PREDICT`.
        """
        # Reject unsupported modes up front, before any graph is built.
        if mode != tf.estimator.ModeKeys.PREDICT:
            raise ValueError("Only PREDICT modes are supported: %s" % (mode))

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        input_type_ids = features["input_type_ids"]

        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=input_type_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        tvars = tf.trainable_variables()
        scaffold_fn = None
        (assignment_map,
         initialized_variable_names) = (
             modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint))

        if use_tpu:

            def tpu_scaffold():
                # On TPU the checkpoint restore is deferred: it only runs
                # when the estimator invokes `scaffold_fn`.
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Log every trainable variable and flag the ones that the assignment
        # map reports as initialized from the checkpoint.
        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        all_layers = model.get_all_encoder_layers()

        predictions = {
            "unique_id": unique_ids,
        }
        # Keys are numbered by position in `layer_indexes`, not by the layer
        # index itself.
        for (i, layer_index) in enumerate(layer_indexes):
            predictions["layer_output_%d" % i] = all_layers[layer_index]

        output_spec = contrib_tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec

    return model_fn
示例2: model_fn_builder
# 需要导入模块: from tensorflow.contrib import tpu as contrib_tpu
# 或者: from tensorflow.contrib.tpu import TPUEstimatorSpec [as 别名]
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps, use_tpu, bert_hub_module_handle):
    """Returns `model_fn` closure for TPUEstimator.

    Args:
        num_labels: Forwarded to `create_model`.
        learning_rate: Forwarded to `optimization.create_optimizer` in TRAIN
            mode.
        num_train_steps: Forwarded to `optimization.create_optimizer` in
            TRAIN mode.
        num_warmup_steps: Forwarded to `optimization.create_optimizer` in
            TRAIN mode.
        use_tpu: Forwarded to `optimization.create_optimizer` in TRAIN mode.
        bert_hub_module_handle: Forwarded to `create_model`.

    Returns:
        A `model_fn(features, labels, mode, params)` callable that handles
        the TRAIN, EVAL and PREDICT modes.
    """

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Reads `input_ids`, `input_mask`, `segment_ids` and `label_ids` from
        `features`, builds the model through `create_model`, and returns the
        `TPUEstimatorSpec` matching `mode`.

        Raises:
            ValueError: If `mode` is not TRAIN, EVAL or PREDICT.
        """
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            # Lazy logging arguments: the message is only formatted if the
            # record is actually emitted.
            tf.logging.info(" name = %s, shape = %s", name,
                            features[name].shape)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits, probabilities) = create_model(
            is_training, input_ids, input_mask, segment_ids, label_ids,
            num_labels, bert_hub_module_handle)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu)
            return contrib_tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, train_op=train_op)

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits):
                """Builds the accuracy and mean-loss evaluation metrics."""
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(label_ids, predictions)
                loss = tf.metrics.mean(per_example_loss)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            # The spec receives the metric function together with the
            # tensors to call it with, rather than already-built metrics.
            eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
            return contrib_tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, eval_metrics=eval_metrics)

        if mode == tf.estimator.ModeKeys.PREDICT:
            return contrib_tpu.TPUEstimatorSpec(
                mode=mode, predictions={"probabilities": probabilities})

        raise ValueError(
            "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))

    return model_fn