本文整理汇总了 Python 中 tensorflow.python.training.training_util.get_global_step 函数的典型用法代码示例。如果您想了解 get_global_step 函数的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。
在下文中一共展示了 get_global_step 函数的 15 个代码示例，这些例子默认根据受欢迎程度排序。注意：各示例均为从原始源文件中摘录的片段，其中引用的 ops、training_util 等名称依赖各自源文件中的导入和外层定义，不能单独运行。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: test_get_global_step
def test_get_global_step(self):
    """Checks get_global_step() before and after a GLOBAL_STEP variable exists.

    In a fresh graph the lookup returns None. After a non-trainable int32
    variable named ops.GraphKeys.GLOBAL_STEP is created, the lookup result is
    validated by _assert_global_step, first via the implicit default graph and
    then with the graph passed explicitly.
    """
    with ops.Graph().as_default() as graph:
        # Nothing has been created yet, so the lookup must come back empty.
        self.assertIsNone(training_util.get_global_step())

        variables.VariableV1(
            0,
            trainable=False,
            dtype=dtypes.int32,
            name=ops.GraphKeys.GLOBAL_STEP,
        )

        # () -> implicit default graph; (graph,) -> explicit graph argument.
        for lookup_args in ((), (graph,)):
            self._assert_global_step(
                training_util.get_global_step(*lookup_args),
                expected_dtype=dtypes.int32)
示例2: _ModelFn
def _ModelFn(features, labels, mode):
    """Estimator model_fn: builds logits, loss and accuracy, then an EstimatorSpec.

    When `is_training` is true the graph is built directly with
    self._BuildGraph; otherwise the GraphDef from
    self._GetGraphDef(use_trt, batch_size, model_dir) is imported with
    `features` fed into INPUT_NODE_NAME and OUTPUT_NODE_NAME:0 used as logits.

    Returns an EstimatorSpec for ModeKeys.EVAL and ModeKeys.TRAIN, and None
    for any other mode.

    NOTE(review): is_training, self, use_trt, batch_size and model_dir are
    free variables captured from an enclosing scope not shown here.
    """
    if not is_training:
        imported = importer.import_graph_def(
            self._GetGraphDef(use_trt, batch_size, model_dir),
            input_map={INPUT_NODE_NAME: features},
            return_elements=[OUTPUT_NODE_NAME + ':0'],
            name='')
        logits = imported[0]
    else:
        logits = self._BuildGraph(features)

    loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    summary.scalar('loss', loss)

    predicted_classes = math_ops.argmax(logits, axis=1, name='classes_out')
    accuracy = metrics.accuracy(
        labels=labels, predictions=predicted_classes, name='acc_op')
    # Presumably accuracy is a (value, update_op) pair; the summary logs [1].
    summary.scalar('accuracy', accuracy[1])

    if mode == ModeKeys.TRAIN:
        train_op = AdamOptimizer(learning_rate=1e-2).minimize(
            loss, global_step=get_global_step())
        return EstimatorSpec(mode, loss=loss, train_op=train_op)
    if mode == ModeKeys.EVAL:
        return EstimatorSpec(
            mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
    # Any other mode falls through without a spec.
    return None
示例3: _train_op_fn
def _train_op_fn(loss):
    """Returns the op to optimize the loss.

    Minimizes `loss` with dnn_optimizer over the trainable variables under
    dnn_parent_scope (only when dnn_logits is not None), then with
    linear_optimizer over those under linear_parent_scope (only when
    linear_logits is not None). The resulting ops are grouped, and the
    returned op increments the global step by one after they have run.
    Neither minimize() call receives global_step, so this single increment
    is the only global-step update.
    """
    global_step = training_util.get_global_step()

    def _minimize_in_scope(sub_optimizer, scope):
        # Restrict the update to the trainable variables of one sub-model.
        return sub_optimizer.minimize(
            loss,
            var_list=ops.get_collection(
                ops.GraphKeys.TRAINABLE_VARIABLES, scope=scope))

    train_ops = []
    # Optimizers are referenced only inside their guards: they may not be
    # bound in the enclosing scope when the matching logits are None.
    if dnn_logits is not None:
        train_ops.append(_minimize_in_scope(dnn_optimizer, dnn_parent_scope))
    if linear_logits is not None:
        train_ops.append(
            _minimize_in_scope(linear_optimizer, linear_parent_scope))

    grouped_train_op = control_flow_ops.group(*train_ops)
    with ops.control_dependencies([grouped_train_op]), \
            ops.colocate_with(global_step):
        return state_ops.assign_add(global_step, 1)
# Tail of the enclosing model_fn (its header is outside this excerpt): pass
# the features, mode, labels, logits and the _train_op_fn callback defined
# above to head.create_estimator_spec and return its result.
return head.create_estimator_spec(
    features=features,
    mode=mode,
    labels=labels,
    train_op_fn=_train_op_fn,
    logits=logits)
示例4: record_summaries_every_n_global_steps
def record_summaries_every_n_global_steps(n):
    """Sets the should_record_summaries Tensor to true if global_step % n == 0.

    Generator body of a context manager (any decorator is outside this
    excerpt). While the context is active, the _SHOULD_RECORD_SUMMARIES_NAME
    collection holds a single boolean tensor that is true when the global
    step is a multiple of `n`. The previous contents are restored on exit,
    including when the body of the `with` block raises.

    Args:
      n: Record summaries every `n` global steps.

    Yields:
      None, once.
    """
    from tensorflow.python.ops import math_ops  # pylint: disable=g-import-not-at-top

    collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    old = collection_ref[:]
    # Use math_ops.equal rather than Python `==`. On a TF1 graph tensor,
    # `==` is an identity comparison that yields the Python bool False, so
    # summaries would never be recorded.
    with ops.device("cpu:0"):
        should_record = math_ops.equal(training_util.get_global_step() % n, 0)
    try:
        collection_ref[:] = [should_record]
        yield
    finally:
        # Always restore the caller's collection state.
        collection_ref[:] = old
示例5: before_run
def before_run(self, run_context):
    """Requests the global step and the current loss for this session.run.

    The loss is self.loss_op when it is set. Otherwise it is the first output
    of the operation named LOSS_NAME in the session's graph.

    Returns:
      SessionRunArgs fetching {'global_step': ..., 'current_loss': ...}.
    """
    loss = self.loss_op
    if loss is None:
        session_graph = run_context.session.graph
        loss = session_graph.get_operation_by_name(LOSS_NAME).outputs[0]
    fetches = {
        'global_step': training_util.get_global_step(),
        'current_loss': loss,
    }
    return session_run_hook.SessionRunArgs(fetches)
示例6: __init__
def __init__(self):
    """Prepares an op that increments the global step by one, if one exists.

    Sets self._global_step_incr_op to the `.op` of an assign_add named
    "global_step_incr". It is set to None when
    training_util.get_global_step() returns None.
    """
    global_step = training_util.get_global_step()
    # Compare against None explicitly. Truth-testing a graph variable/tensor
    # is not an existence check: TF may refuse to convert it to a Python bool.
    if global_step is not None:
        self._global_step_incr_op = state_ops.assign_add(
            global_step, 1, name="global_step_incr").op
    else:
        self._global_step_incr_op = None
示例7: begin
def begin(self):
    """Resets the last-reported time/step and caches the global step tensor.

    Raises:
      RuntimeError: if training_util.get_global_step() returns None.
    """
    self._last_reported_time = self._last_reported_step = None
    step_tensor = training_util.get_global_step()
    self._global_step_tensor = step_tensor
    if step_tensor is None:
        raise RuntimeError(
            "Global step should be created to use StepCounterHook.")
示例8: function
def function(tag, scope):
    """Writes an image summary of `tensor` under `tag` at the global step.

    Uses the caller-supplied `bad_color` when given. Otherwise it uses the
    uint8 constant [255, 0, 0, 255] (presumably opaque red, RGBA).

    NOTE(review): bad_color, tensor and max_images are free variables from
    the enclosing scope, which is not shown here.

    Args:
      tag: Summary tag.
      scope: Name for the write op.

    Returns:
      The op created by gen_summary_ops.write_image_summary.
    """
    # bad_color_ must be bound on both branches. Binding it only when
    # bad_color is None left it undefined for caller-supplied colors.
    if bad_color is None:
        bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
    else:
        bad_color_ = bad_color
    # Return the op so callers can add control dependencies on the write.
    return gen_summary_ops.write_image_summary(
        context.context().summary_writer_resource,
        training_util.get_global_step(), tag, tensor, bad_color_, max_images,
        name=scope)
示例9: get_updates
def get_updates(self, loss, params):
    """Builds the optimizer update ops, stores them in self.updates and returns them.

    Without a distribution strategy:
      - An empty `params` yields only an op incrementing self.iterations.
      - Otherwise gradients for `params` are applied with self.iterations as
        the step counter.
    With a distribution strategy, gradients are computed (for all variables
    when `params` is empty) and applied with training_util's global step.
    """
    if not distribute_lib.has_distribution_strategy():
        if not params:
            self.updates = [state_ops.assign_add(self.iterations, 1)]
            return self.updates
        # Updates list starts out empty because the iterations variable is
        # incremented in optimizer.apply_gradients().
        self.updates = []
        gradients = self.optimizer.compute_gradients(loss, params)
        update_op = self.optimizer.apply_gradients(
            gradients, global_step=self.iterations)
    else:
        self.updates = []
        # After the model vars have been created, the second call arrives with
        # params as an empty list; omit params so compute_gradients gets None.
        if params:
            gradients = self.optimizer.compute_gradients(loss, params)
        else:
            gradients = self.optimizer.compute_gradients(loss)
        update_op = self.optimizer.apply_gradients(
            gradients, training_util.get_global_step())
    self.updates.append(update_op)
    return self.updates
示例10: __init__
def __init__(self,
             checkpoint_dir,
             display_steps=100,
             maximum_train_steps=None,
             do_summary=True,
             is_chief=True):
    """Initializes the hook.

    Args:
      checkpoint_dir: A string, base directory for the checkpoint files.
      display_steps: A python integer, display every N steps.
      maximum_train_steps: A python integer, the maximum training steps.
      do_summary: Whether to save summaries when displaying.
      is_chief: Whether this is the chief process (currently unused).
    """
    tf.logging.info("Create DisplayHook.")
    self._checkpoint_dir = checkpoint_dir
    self._display_steps = display_steps
    self._maximum_train_steps = maximum_train_steps
    self._do_summary = do_summary
    self._is_chief = is_chief  # not used now

    # Pair each entry of the display-key collection with the entry at the
    # same position in the display-value collection (zip stops at the
    # shorter one), then add the global step under "global_step".
    step = training_util.get_global_step()
    names = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
    values = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
    display_args = {name: value for name, value in zip(names, values)}
    display_args["global_step"] = step
    self._display_args = display_args

    # Timers and the summary writer start out unset.
    self._timer = None
    self._logging_timer = None
    self._summary_writer = None
示例11: begin
def begin(self):
    """Caches the global step tensor and calls begin() on every listener.

    Raises:
      RuntimeError: if training_util.get_global_step() returns None.
    """
    step_tensor = training_util.get_global_step()
    self._global_step_tensor = step_tensor
    if step_tensor is None:
        raise RuntimeError(
            "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
        listener.begin()
示例12: begin
def begin(self):
    """Resets save bookkeeping and caches the global step tensor.

    Clears the last saved step, requests a summary, and stores the current
    global step tensor.

    Raises:
      RuntimeError: if training_util.get_global_step() returns None.
    """
    self._last_saved_step = None
    self._request_summary = True
    step_tensor = training_util.get_global_step()
    self._global_step_tensor = step_tensor
    if step_tensor is None:
        raise RuntimeError(
            "Global step should be created to use SummarySaverHook.")
示例13: after_create_session
def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    """Installs the graceful shutdown hook once the training session exists.

    Performs these steps in order:
      1. Checks that a global step exists whenever self.saver() returns a
         saver.
      2. Inside self._graph, clones `training_session` into self._session.
      3. Builds a WorkerHeartbeatManager over all worker devices.
      4. Configures the workers with shutdown_mode=WAIT_FOR_COORDINATOR.
    self._heartbeat_supported ends up False, with a warning, when there are
    no workers or configure() raises InvalidArgumentError.

    Args:
      training_session: The session that was just created.
      coord: Unused.

    Raises:
      ValueError: if self.saver() is not None but no global step exists.
    """
    # N.B. We have to pull the global step here to avoid it being unavailable
    # at checkpoint time; the graph has been frozen at that point.
    if training_util.get_global_step() is None and self.saver() is not None:
        raise ValueError(
            'Saver defined but no global step. Run `get_or_create_global_step()`'
            ' in your model definition to allow checkpointing.')
    # NOTE(review): self._graph is presumably a graph owned by this hook,
    # separate from the frozen training graph; confirm where it is created.
    with self._graph.as_default():
        logging.info('Installing graceful shutdown hook.')
        self._session = _clone_session(training_session, self._graph)
        self._workers = WorkerHeartbeatManager.from_devices(
            self._session, all_worker_devices(self._session))
        # Heartbeats are only attempted when at least one worker was found.
        self._heartbeat_supported = self._workers.num_workers() > 0
        if self._heartbeat_supported:
            try:
                self._workers.configure(
                    event_pb2.WorkerHeartbeatRequest(
                        shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
            except errors.InvalidArgumentError:
                # Degrade gracefully: disable failure handling instead of
                # failing session creation.
                logging.warn(
                    'TPU device does not support heartbeats. Failure '
                    'handling will be disabled.')
                self._heartbeat_supported = False
        else:
            logging.warn(
                'No workers support hearbeats. Failure handling will be disabled.')
示例14: _train_op_fn
def _train_op_fn(unused_loss):
    """Returns the SDCA train op; the loss argument is ignored.

    Builds the step with optimizer.get_train_step at the current global step.
    If update_weights_hook is provided, the SDCA model and train op are also
    passed to it via set_parameters.
    """
    sdca_model, train_op = optimizer.get_train_step(
        columns_to_variables, weight_column_name, loss_type, features, labels,
        training_util.get_global_step())
    if update_weights_hook is None:
        return train_op
    update_weights_hook.set_parameters(sdca_model, train_op)
    return train_op
示例15: record_summaries_every_n_global_steps
def record_summaries_every_n_global_steps(n):
    """Sets the should_record_summaries Tensor to true if global_step % n == 0.

    Generator body of a context manager (any decorator is outside this
    excerpt). While the context is active, the _SHOULD_RECORD_SUMMARIES_NAME
    collection holds a single boolean tensor, built on cpu:0, that is true
    when the global step is a multiple of `n`. The previous contents are
    restored on exit, including when the body of the `with` block raises.

    Args:
      n: Record summaries every `n` global steps.

    Yields:
      None, once.
    """
    collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    old = collection_ref[:]
    with ops.device("cpu:0"):
        should_record = math_ops.equal(training_util.get_global_step() % n, 0)
    try:
        collection_ref[:] = [should_record]
        yield
    finally:
        # Always restore the caller's collection state.
        collection_ref[:] = old