This article collects typical usage examples of the Python method tensorflow.python.training.training_util.create_global_step. If you want to know what training_util.create_global_step does, how to call it, or what real code that uses it looks like, the selected examples below may help. You can also browse more usage examples for the module that defines it, tensorflow.python.training.training_util.
Three code examples of training_util.create_global_step are shown below, sorted by popularity by default.
Example 1: create_global_step
# Module to import: from tensorflow.python.training import training_util [as alias]
# Or: from tensorflow.python.training.training_util import create_global_step [as alias]
def create_global_step(graph=None):
  """Create global step tensor in graph.

  This API is deprecated. Use core framework training version instead.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
  """
  return training_util.create_global_step(graph)
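The Raises clause above makes create_global_step a create-once operation. Each graph may hold only one global step, and a second call on the same graph fails. The plain-Python sketch below models just that contract so it can run without TensorFlow. Graph, GLOBAL_STEP_NAME and the three functions are hypothetical stand-ins, not TensorFlow's implementation, and the exact error text is an assumption.

```python
# Toy model of the "one global step per graph" contract described above.
# Everything here is a hypothetical stand-in, not TensorFlow code.

GLOBAL_STEP_NAME = "global_step"


class Graph:
    """Hypothetical stand-in for a graph: it only stores named variables."""

    def __init__(self):
        self.variables = {}


_DEFAULT_GRAPH = Graph()


def get_global_step(graph=None):
    """Return the global step of `graph` (default graph if None), or None if absent."""
    if graph is None:
        graph = _DEFAULT_GRAPH
    return graph.variables.get(GLOBAL_STEP_NAME)


def create_global_step(graph=None):
    """Create the global step in `graph`; raise ValueError if it already exists."""
    if graph is None:
        graph = _DEFAULT_GRAPH
    if get_global_step(graph) is not None:
        raise ValueError('"global_step" already exists.')
    graph.variables[GLOBAL_STEP_NAME] = 0  # the step counter starts at zero
    return graph.variables[GLOBAL_STEP_NAME]


def get_or_create_global_step(graph=None):
    """Return the existing global step, creating it first if it is missing."""
    step = get_global_step(graph)
    return step if step is not None else create_global_step(graph)


g = Graph()
print(create_global_step(g))          # prints 0
try:
    create_global_step(g)             # second call on the same graph
except ValueError as err:
    print(err)                        # prints "global_step" already exists.
print(get_or_create_global_step(g))   # prints 0 (reuses the existing step)
```

When a missing step should simply be created, TensorFlow 1.x code usually calls tf.train.get_or_create_global_step rather than catching the ValueError. The toy get_or_create_global_step mirrors that idea.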
Example 2: create_global_step
# Module to import: from tensorflow.python.training import training_util [as alias]
# Or: from tensorflow.python.training.training_util import create_global_step [as alias]
def create_global_step(graph=None):
  """Create global step tensor in graph.

  This API is deprecated. Use core framework training version instead.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
  """
  return training_util.create_global_step(graph)
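Examples 1 and 2 are thin compatibility wrappers. The old entry point keeps its signature and forwards every call to the core implementation, and only its docstring marks it as deprecated. The sketch below shows that delegation pattern in plain Python. It assumes the wrapper should also warn at call time, which the excerpts above do not show. The deprecated decorator and core_create_global_step are hypothetical names, and this is not a claim about how TensorFlow's own source is written.

```python
import functools
import warnings


def deprecated(replacement):
    """Hypothetical decorator: warn that the wrapped function is superseded by `replacement`."""
    def decorator(func):
        @functools.wraps(func)  # keep the legacy name and docstring visible
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated; use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,  # point the warning at the caller, not at this wrapper
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator


def core_create_global_step(graph=None):
    """Hypothetical stand-in for the core implementation (training_util.create_global_step)."""
    return {"graph": graph, "name": "global_step", "value": 0}


@deprecated("core_create_global_step")
def create_global_step(graph=None):
    """Legacy entry point: same signature, delegates straight to the core version."""
    return core_create_global_step(graph)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the DeprecationWarning is recorded
    step = create_global_step()
print(step["value"], caught[0].category.__name__)  # prints: 0 DeprecationWarning
```

Delegating instead of copying the body keeps a single source of truth. Old callers keep working while they migrate to the core API.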
Example 3: _save_first_checkpoint
# Module to import: from tensorflow.python.training import training_util [as alias]
# Or: from tensorflow.python.training.training_util import create_global_step [as alias]
# Note: ops, random_seed, training_util, model_fn_lib, models, saver_lib,
# session, K and the private helper _clone_and_build_model are imported or
# defined at module level in the original estimator.py (not shown here).
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
  """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    estimator: keras estimator.
    custom_objects: Dictionary for custom objects.
    keras_weights: A flat list of Numpy arrays for weights of given keras_model.

  Returns:
    None. A checkpoint is written to `estimator.model_dir` only if that
    directory does not already contain one.
  """
  with ops.Graph().as_default() as g, g.device(estimator._device_fn):
    random_seed.set_random_seed(estimator.config.tf_random_seed)
    # Create the global step in this fresh graph so it is saved with the model.
    training_util.create_global_step()
    model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                   custom_objects)
    if isinstance(model, models.Sequential):
      model = model.model
    # Load weights and save to checkpoint if there is no checkpoint
    latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
    if not latest_path:
      with session.Session() as sess:
        model.set_weights(keras_weights)
        # Make update ops and initialize all variables.
        if not model.train_function:
          # pylint: disable=protected-access
          model._make_train_function()
          K._initialize_variables(sess)
          # pylint: enable=protected-access
        saver = saver_lib.Saver()
        saver.save(sess, estimator.model_dir + '/')
Developer ID: PacktPublishing; project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda; lines of code: 36; source file: estimator.py
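Example 3 writes the initial checkpoint only when saver_lib.latest_checkpoint finds nothing in estimator.model_dir, so an existing training run is never overwritten. It also calls create_global_step first, and a Saver built without arguments saves all saveable global variables, so the step counter ends up in that first checkpoint. The sketch below models only that guard in plain Python. The JSON file format, the ckpt-N.json naming and both helper functions are hypothetical; they are not TensorFlow's checkpoint format or its Saver API.

```python
import json
import os
import tempfile

PREFIX, SUFFIX = "ckpt-", ".json"  # hypothetical naming scheme: ckpt-<step>.json


def latest_checkpoint(model_dir):
    """Return the path of the highest-step checkpoint in model_dir, or None if there is none."""
    if not os.path.isdir(model_dir):
        return None
    steps = [
        int(name[len(PREFIX):-len(SUFFIX)])
        for name in os.listdir(model_dir)
        if name.startswith(PREFIX) and name.endswith(SUFFIX)
        and name[len(PREFIX):-len(SUFFIX)].isdigit()
    ]
    if not steps:
        return None
    return os.path.join(model_dir, f"{PREFIX}{max(steps)}{SUFFIX}")


def save_first_checkpoint(model_dir, initial_weights):
    """Write a step-0 checkpoint only if model_dir has none; return the checkpoint path."""
    existing = latest_checkpoint(model_dir)
    if existing is not None:
        return existing  # keep whatever progress is already on disk
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, f"{PREFIX}0{SUFFIX}")
    with open(path, "w") as f:
        # The global step is stored next to the weights, starting at 0.
        json.dump({"global_step": 0, "weights": initial_weights}, f)
    return path


with tempfile.TemporaryDirectory() as model_dir:
    first = save_first_checkpoint(model_dir, [[1.0, 2.0], [3.0]])
    second = save_first_checkpoint(model_dir, [[9.9]])
    print(first == second)  # prints True: the existing checkpoint is kept
```

Making the save idempotent means the function can run at every Estimator construction. Only the very first call in an empty model_dir seeds the weights.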