This article collects typical usage examples of the Python method tensorflow.python.training.training_util.create_global_step. If you are unsure how to use training_util.create_global_step in Python or are looking for examples of it, the hand-picked code examples below may help. You can also explore further usage examples for the module that contains this method, tensorflow.python.training.training_util.
Three code examples of training_util.create_global_step are shown below, sorted by popularity by default.
Example 1: create_global_step
# Module to import: from tensorflow.python.training import training_util [as alias]
# Or: from tensorflow.python.training.training_util import create_global_step [as alias]
def create_global_step(graph=None):
  """Create global step tensor in graph.
  This API is deprecated. Use core framework training version instead.
  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.
  Returns:
    Global step tensor.
  Raises:
    ValueError: if global step tensor is already defined.
  """
  return training_util.create_global_step(graph)
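The wrapper above only forwards to training_util.create_global_step. A minimal sketch of the contract its docstring describes, assuming TensorFlow 2.x, where the core function is exposed as tf.compat.v1.train.create_global_step: the first call on a graph creates the step variable, and a second call on the same graph raises ValueError. The names graph, step and duplicate_error are illustrative.

```python
import tensorflow as tf  # assumes TensorFlow 2.x; the v1 graph API lives in tf.compat.v1

graph = tf.Graph()
duplicate_error = None

with graph.as_default():  # graph mode: create_global_step targets this graph
    # First call: creates a non-trainable int64 scalar variable named
    # "global_step" (initialized to 0) and adds it to the GLOBAL_STEP collection.
    step = tf.compat.v1.train.create_global_step()
    print(step.name)                   # global_step:0
    print(step.dtype.base_dtype.name)  # int64

    # Second call on the same graph: the step already exists, so it raises.
    try:
        tf.compat.v1.train.create_global_step()
    except ValueError as err:
        duplicate_error = err
        print("ValueError:", err)
```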
Example 2: create_global_step
# Module to import: from tensorflow.python.training import training_util [as alias]
# Or: from tensorflow.python.training.training_util import create_global_step [as alias]
def create_global_step(graph=None):
  """Create global step tensor in graph.
  This API is deprecated. Use core framework training version instead.
  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.
  Returns:
    Global step tensor.
  Raises:
    ValueError: if global step tensor is already defined.
  """
  return training_util.create_global_step(graph)
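The graph argument lets you create the step in a graph other than the current default one. A minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 API; the outer with block only keeps execution in graph mode, because eager execution is on at the top level of a TF 2.x program and the graph-mode behavior shown here does not apply there. The names default_graph, target_graph, found and same are illustrative.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with the tf.compat.v1 API

default_graph = tf.Graph()  # the graph that is "current" in this example
target_graph = tf.Graph()   # the graph the step should actually live in

with default_graph.as_default():  # graph mode, so eager execution is off here
    # Explicit graph=: the step is created in target_graph, not in the
    # current default graph.
    step = tf.compat.v1.train.create_global_step(graph=target_graph)

    print(tf.compat.v1.train.get_global_step(graph=default_graph))  # None
    found = tf.compat.v1.train.get_global_step(graph=target_graph)
    print(found.name)  # global_step:0

    # get_or_create_global_step returns the existing step instead of raising.
    same = tf.compat.v1.train.get_or_create_global_step(graph=target_graph)
```

get_or_create_global_step is the forgiving option when several components may each need the step; create_global_step is the strict one that reports a duplicate with ValueError.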
Example 3: _save_first_checkpoint
# Module to import: from tensorflow.python.training import training_util [as alias]
# Or: from tensorflow.python.training.training_util import create_global_step [as alias]
# Excerpt from estimator.py: ops, random_seed, session, saver_lib, models, K,
# model_fn_lib and _clone_and_build_model are imported or defined elsewhere in
# that module and are not shown here.
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
  """Save first checkpoint for the keras Estimator.
  Args:
    keras_model: an instance of compiled keras model.
    estimator: keras estimator.
    custom_objects: Dictionary for custom objects.
    keras_weights: A flat list of Numpy arrays for weights of given keras_model.
  Returns:
    None. The checkpoint is written to estimator.model_dir as a side effect.
  """
  with ops.Graph().as_default() as g, g.device(estimator._device_fn):
    random_seed.set_random_seed(estimator.config.tf_random_seed)
    training_util.create_global_step()
    model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                   custom_objects)
    if isinstance(model, models.Sequential):
      model = model.model
    # Load weights and save to checkpoint if there is no checkpoint
    latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
    if not latest_path:
      with session.Session() as sess:
        model.set_weights(keras_weights)
        # Make update ops and initialize all variables.
        if not model.train_function:
          # pylint: disable=protected-access
          model._make_train_function()
          K._initialize_variables(sess)
          # pylint: enable=protected-access
        saver = saver_lib.Saver()
        saver.save(sess, estimator.model_dir + '/')
Developer ID: PacktPublishing, Project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda, Lines of code: 36, Source: estimator.py
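Stripped of the Keras and Estimator specifics, _save_first_checkpoint follows a simple pattern: build a fresh graph, create the global step alongside the model variables, and write an initial checkpoint only if the model directory does not already contain one. The sketch below assumes TensorFlow 2.x and uses tf.compat.v1 graph-mode APIs; the helper save_initial_checkpoint, the single variable w and the model.ckpt prefix are hypothetical stand-ins for the cloned Keras model and the Estimator's directory layout.

```python
import os
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x with the tf.compat.v1 API


def save_initial_checkpoint(model_dir, seed=42):
    """Hypothetical helper: write a first checkpoint unless one already exists."""
    latest = tf.train.latest_checkpoint(model_dir)
    if latest:
        # A checkpoint is already there; keep it rather than overwrite it.
        return latest

    with tf.Graph().as_default():
        tf.compat.v1.set_random_seed(seed)
        # The global step is saved with the model variables, at value 0.
        tf.compat.v1.train.create_global_step()
        # Hypothetical stand-in for the model's weights.
        tf.compat.v1.get_variable(
            "w", shape=[3], initializer=tf.compat.v1.random_normal_initializer())

        saver = tf.compat.v1.train.Saver()  # saves all global variables
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            return saver.save(sess, os.path.join(model_dir, "model.ckpt"))


model_dir = tempfile.mkdtemp()
first = save_initial_checkpoint(model_dir)
second = save_initial_checkpoint(model_dir)  # finds the existing checkpoint
print(first == second)  # True
print(tf.train.load_variable(first, "global_step"))  # 0
```

As in the original, the existence check runs first so that a directory that already holds a trained model is never overwritten. Unlike the original, which loads the caller's Keras weights with model.set_weights, this sketch simply initializes w randomly.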