本文整理汇总了Python中tensorflow.python.training.training_util.get_or_create_global_step方法的典型用法代码示例。如果您正苦于以下问题:Python training_util.get_or_create_global_step方法的具体用法?Python training_util.get_or_create_global_step怎么用?Python training_util.get_or_create_global_step使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.training.training_util的用法示例。
在下文中一共展示了training_util.get_or_create_global_step方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _initialize_in_memory_eval
# 需要导入模块: from tensorflow.python.training import training_util [as 别名]
# 或者: from tensorflow.python.training.training_util import get_or_create_global_step [as 别名]
def _initialize_in_memory_eval(estimator):
  """Set up a TPUEstimator for in-memory evaluation.

  Replaces the estimator's `_create_global_step` with a function that calls
  `training_util.get_or_create_global_step()` inside a resource-variable
  scope. Also installs a fresh `ErrorRendezvous(3)` for the EVAL and the
  PREDICT mode.

  Args:
    estimator: the TPUEstimator to patch in place.
  """

  # estimator.evaluate calls _create_global_step unconditionally, so override
  # it. The graph argument is ignored; get_or_create_global_step() is called
  # without a graph and therefore works on the default graph.
  def _resource_global_step(unused_graph):
    with variable_scope.variable_scope('', use_resource=True):
      return training_util.get_or_create_global_step()

  estimator._create_global_step = _resource_global_step  # pylint: disable=protected-access
  # NOTE(review): the meaning of the argument 3 is defined by
  # error_handling.ErrorRendezvous (presumably a participant count) -- confirm.
  for mode in (model_fn_lib.ModeKeys.EVAL, model_fn_lib.ModeKeys.PREDICT):
    estimator._rendezvous[mode] = error_handling.ErrorRendezvous(3)  # pylint: disable=protected-access
# pylint: disable=protected-access
示例2: get_or_create_global_step
# 需要导入模块: from tensorflow.python.training import training_util [as 别名]
# 或者: from tensorflow.python.training.training_util import get_or_create_global_step [as 别名]
def get_or_create_global_step(graph=None):
  """Return the global step tensor, creating it first if necessary.

  Thin wrapper that forwards to `training_util.get_or_create_global_step`.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      the default graph is used.

  Returns:
    The global step tensor.
  """
  return training_util.get_or_create_global_step(graph=graph)
示例3: begin
# 需要导入模块: from tensorflow.python.training import training_util [as 别名]
# 或者: from tensorflow.python.training.training_util import get_or_create_global_step [as 别名]
def begin(self):
  """Build the in-memory eval graph and one placeholder per global variable.

  A new `ops.Graph` is created and the estimator's model_fn is called in
  PREDICT mode inside it. Each global variable of that graph is recorded by
  name, together with a same-dtype placeholder keyed by the same name.
  NOTE(review): the op that restores values through these placeholders is
  not built here; presumably it is built elsewhere in the hook -- confirm.
  """
  self._timer.reset()
  self._graph = ops.Graph()
  # Read in the outer default graph, before switching to self._graph.
  self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
  with self._graph.as_default():
    # The eval graph's global step is created as a resource variable.
    with variable_scope.variable_scope('', use_resource=True):
      training_util.get_or_create_global_step()

    predict_mode = model_fn_lib.ModeKeys.PREDICT
    estimator = self._estimator
    features, input_hooks = estimator._get_features_from_input_fn(  # pylint: disable=protected-access
        self._input_fn, predict_mode)
    estimator_spec = estimator._call_model_fn(  # pylint: disable=protected-access
        features, None, predict_mode, estimator.config)
    self._all_hooks = list(input_hooks) + list(estimator_spec.prediction_hooks)
    self._predictions = estimator._extract_keys(  # pylint: disable=protected-access
        estimator_spec.predictions, predict_keys=None)

    # Creating placeholders does not add global variables, so one fetch of
    # the collection serves both maps.
    eval_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self._var_name_to_eval_var = {var.name: var for var in eval_vars}
    self._var_name_to_placeholder = {
        var.name: array_ops.placeholder(var.dtype) for var in eval_vars
    }
    logging.info('Placeholders: %s', self._var_name_to_placeholder)

    for hook in self._all_hooks:
      logging.info('Hook: %s', hook)
      if isinstance(hook, tpu_estimator.TPUInfeedOutfeedSessionHook):
        # Turn off this hook's TPU initialization (presumably the TPU is
        # already initialized by the surrounding session -- confirm).
        hook._should_initialize_tpu = False  # pylint: disable=protected-access
示例4: begin
# 需要导入模块: from tensorflow.python.training import training_util [as 别名]
# 或者: from tensorflow.python.training.training_util import get_or_create_global_step [as 别名]
def begin(self):
  """Refresh the summary op when requested, then fetch the global step."""
  replace_summary = self._replace_summary_op
  if replace_summary:
    # merge_all() leaves _summary_op as None when there are no summaries.
    self._summary_op = summary.merge_all()
  self._global_step = training_util.get_or_create_global_step()
示例5: _init_global_step
# 需要导入模块: from tensorflow.python.training import training_util [as 别名]
# 或者: from tensorflow.python.training.training_util import get_or_create_global_step [as 别名]
def _init_global_step(self):
  """Store the global step and register an op that increments it by 1."""
  self.global_step = training_util.get_or_create_global_step()
  increment_op = training_util._increment_global_step(1)  # pylint: disable=protected-access
  self._training_ops.update({'increment_global_step': increment_op})
示例6: apply_gradients
# 需要导入模块: from tensorflow.python.training import training_util [as 别名]
# 或者: from tensorflow.python.training.training_util import get_or_create_global_step [as 别名]
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Wraps the original apply_gradients of the optimizer.

  Runs `self._before_apply_gradients` first, then the wrapped optimizer's
  `apply_gradients`, and finally `self.cond_mask_update_op`, with control
  dependencies enforcing that order.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    global_step: Optional `Variable` to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation. Defaults to the
      name passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients. If `global_step`
    was not None, that operation also increments `global_step`.
  """
  before_op = self._before_apply_gradients(grads_and_vars)
  with ops.control_dependencies([before_op]):
    update_op = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
  # The wrapped optimizer sees the caller's global_step unchanged, so it
  # increments nothing when none was passed. Only afterwards do we fall back
  # to the default global step, which the mask update needs.
  if global_step is None:
    global_step = training_util.get_or_create_global_step()
  self._global_step = global_step
  with ops.control_dependencies([update_op]):
    return self.cond_mask_update_op(global_step, control_flow_ops.no_op)
示例7: _clone_and_build_model
# 需要导入模块: from tensorflow.python.training import training_util [as 别名]
# 或者: from tensorflow.python.training.training_util import get_or_create_global_step [as 别名]
def _clone_and_build_model(mode,
                           keras_model,
                           custom_objects,
                           features=None,
                           labels=None):
  """Clone and build the given keras_model.

  Args:
    mode: training mode, compared by value against `model_fn_lib.ModeKeys`.
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects. When non-empty, the clone
      is made inside a `CustomObjectScope` built from it.
    features: Optional inputs. When given, they are ordered with
      `_create_ordered_io` and used as the clone's input tensors.
    labels: Optional targets. A dict is ordered with `_create_ordered_io`.
      Any other value is converted with `convert_to_tensor_or_sparse_tensor`
      and wrapped in a one-element list.

  Returns:
    The newly built model. For a `models.Sequential` clone, its `.model`
    attribute is returned instead.
  """
  # Set to True during training, False for inference.
  K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

  # Clone keras model.
  input_tensors = None if features is None else _create_ordered_io(
      keras_model, features)
  if custom_objects:
    with CustomObjectScope(custom_objects):
      model = models.clone_model(keras_model, input_tensors=input_tensors)
  else:
    model = models.clone_model(keras_model, input_tensors=input_tensors)

  # Compile/Build model.
  # Compare the mode with `==` (as the TRAIN check above does): `is` tests
  # object identity and misses an equal mode value held in a different object.
  if mode == model_fn_lib.ModeKeys.PREDICT and not model.built:
    model.build()
  else:
    # Fresh optimizer of the same class and config, driven by the global step.
    optimizer_config = keras_model.optimizer.get_config()
    optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
    optimizer.iterations = training_util.get_or_create_global_step()

    # Get list of outputs.
    if labels is None:
      target_tensors = None
    elif isinstance(labels, dict):
      target_tensors = _create_ordered_io(keras_model, labels, is_input=False)
    else:
      target_tensors = [
          sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels)
      ]

    model.compile(
        optimizer,
        keras_model.loss,
        metrics=keras_model.metrics,
        loss_weights=keras_model.loss_weights,
        sample_weight_mode=keras_model.sample_weight_mode,
        weighted_metrics=keras_model.weighted_metrics,
        target_tensors=target_tensors)

  if isinstance(model, models.Sequential):
    model = model.model
  return model
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:61,代码来源:estimator.py