本文整理汇总了Python中tensorflow.contrib.framework.create_global_step函数的典型用法代码示例。如果您正苦于以下问题：Python create_global_step函数的具体用法？Python create_global_step怎么用？Python create_global_step使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了create_global_step函数的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _infer_model
def _infer_model(
        self, input_fn, feed_fn=None, outputs=None, as_iterable=False):
    """Build a fresh inference graph and run predictions from the latest checkpoint.

    Args:
      input_fn: Passed to `self._get_features_from_input_fn` to obtain the
        feature tensors.
      feed_fn: Optional feed function, forwarded unchanged to the selected
        inference helper.
      outputs: Optional collection of prediction keys; when truthy, only the
        prediction entries whose key is in it are kept.
      as_iterable: If true, delegate to `self._infer_model_as_iterable`;
        otherwise delegate to `self._infer_model_single`.

    Returns:
      Whatever the selected inference helper returns.

    Raises:
      NotFittedError: If no checkpoint is found under `self._model_dir`.
      ValueError: If `outputs` filters away every prediction key.
    """
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        # Without a checkpoint there are no trained weights to restore.
        raise NotFittedError("Couldn't find trained model at %s."
                             % self._model_dir)

    with ops.Graph().as_default() as graph:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(graph)
        features = self._get_features_from_input_fn(input_fn)
        predictions = self._get_predict_ops(features)

        # A non-dict prediction is wrapped under the 'predictions' key;
        # `return_dict` tells the helper whether to unwrap it again.
        return_dict = isinstance(predictions, dict)
        if not return_dict:
            predictions = {'predictions': predictions}

        if outputs:
            # Keep a view of the unfiltered keys for the error message.
            existing_keys = predictions.keys()
            predictions = {
                name: tensor
                for name, tensor in predictions.items()
                if name in outputs
            }
            if not predictions:
                raise ValueError('Expected to run at least one output from %s, '
                                 'provided %s.' % (existing_keys, outputs))

        run_inference = (self._infer_model_as_iterable if as_iterable
                         else self._infer_model_single)
        return run_inference(checkpoint_path, predictions, feed_fn, return_dict)
示例2: _infer_model
def _infer_model(self,
                 x=None, input_fn=None, feed_fn=None,
                 batch_size=None, axis=None, proba=False):
    """Run inference from the latest checkpoint in `self._model_dir`.

    If `x` is given, `input_fn` and `feed_fn` are derived from it via
    `_get_predict_input_fn`. `axis` and `proba` are accepted but not used in
    this body.

    Returns:
      If `feed_fn` is None, the result of a single `infer` call. Otherwise, a
      dict mapping each output key to the per-feed results concatenated along
      axis 0. Feeding stops when `feed_fn` raises StopIteration or returns
      None.
    """
    if batch_size is None:
        batch_size = -1
    if x is not None:
        input_fn, feed_fn = _get_predict_input_fn(x, batch_size)
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(g)
        features, _ = input_fn()
        predictions = self._get_predict_ops(features)
        if not isinstance(predictions, dict):
            predictions = {'predictions': predictions}
        # TODO(ipolosukhin): Support batching
        if feed_fn is None:
            return infer(checkpoint_path, predictions)

        collected = {}
        while True:
            try:
                feed_dict = feed_fn()
            except StopIteration:
                feed_dict = None
            if feed_dict is None:
                break
            batch_outputs = infer(checkpoint_path, predictions,
                                  feed_dict=feed_dict)
            for name in batch_outputs:
                collected.setdefault(name, []).append(batch_outputs[name])
        return {name: np.concatenate(chunks, axis=0)
                for name, chunks in collected.items()}
示例3: _infer_model
def _infer_model(self, x=None, input_fn=None, feed_fn=None, batch_size=None):
    """Run inference from the latest checkpoint in `self._model_dir`.

    Args:
      x: Optional input data; when given, `input_fn`/`feed_fn` are built from
        it with `_get_predict_input_fn`.
      input_fn: Passed to `self._get_features_from_input_fn`.
      feed_fn: Optional callable returning one feed dict per call and raising
        StopIteration when exhausted.
      batch_size: Batch size for `_get_predict_input_fn`; None maps to -1.

    Returns:
      A dict of predictions, or the bare 'predictions' entry when the model's
      predict ops were not a dict. With a `feed_fn`, each key's per-feed
      outputs are concatenated along axis 0.
    """
    batch_size = -1 if batch_size is None else batch_size
    if x is not None:
        input_fn, feed_fn = _get_predict_input_fn(x, None, batch_size)
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        predictions = self._get_predict_ops(features)
        return_dict = True
        if not isinstance(predictions, dict):
            predictions, return_dict = {'predictions': predictions}, False
        if feed_fn is None:
            preds = infer(checkpoint_path, predictions)
        else:
            def _feed_fn():
                # PEP 479 (default since Python 3.7) turns a StopIteration that
                # escapes a generator body into RuntimeError. Translate
                # feed_fn's exhaustion signal into a normal end of generator.
                while True:
                    try:
                        feed_dict = feed_fn()
                    except StopIteration:
                        return
                    yield feed_dict

            feed_outputs = graph_actions.run_feeds(
                output_dict=predictions,
                feed_dicts=_feed_fn(),
                restore_checkpoint_path=checkpoint_path)
            preds = {
                key: np.concatenate(
                    [output[key] for output in feed_outputs], axis=0)
                for key in predictions
            }
        if return_dict:
            return preds
        return preds['predictions']
示例4: _infer_model
def _infer_model(
        self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
    """Build an inference graph and run predictions from the latest checkpoint.

    Args:
      input_fn: Passed to `self._get_features_from_input_fn`.
      feed_fn: Optional feed function forwarded to the inference helper.
      outputs: Optional collection of prediction keys to keep.
      as_iterable: Selects `self._infer_model_as_iterable` (true) or
        `self._infer_model_single` (false).

    Returns:
      Whatever the selected inference helper returns.

    Raises:
      NotFittedError: If no checkpoint is found under `self._model_dir`.
      ValueError: If `outputs` filters away every prediction key.
    """
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        raise NotFittedError("Couldn't find trained model at %s."
                             % self._model_dir)

    with ops.Graph().as_default() as graph:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(graph)
        features = self._get_features_from_input_fn(input_fn)
        # `_get_predict_ops` normally returns ModelFnOps. Some subclasses still
        # use the legacy signature and return the predictions (a Tensor or a
        # dict of Tensors) directly; both shapes are accepted here.
        # TODO(b/32664904): Update subclasses and drop the legacy branch.
        infer_ops = self._get_predict_ops(features)
        predictions = (infer_ops.predictions
                       if isinstance(infer_ops, model_fn_lib.ModelFnOps)
                       else infer_ops)

        # A single output is wrapped into a dict and unwrapped by the helper.
        return_dict = isinstance(predictions, dict)
        if not return_dict:
            predictions = {'predictions': predictions}

        if outputs:
            existing_keys = predictions.keys()
            predictions = {
                name: tensor
                for name, tensor in six.iteritems(predictions)
                if name in outputs
            }
            if not predictions:
                raise ValueError('Expected to run at least one output from %s, '
                                 'provided %s.' % (existing_keys, outputs))

        run_inference = (self._infer_model_as_iterable if as_iterable
                         else self._infer_model_single)
        return run_inference(checkpoint_path, predictions, feed_fn, return_dict)
示例5: _evaluate_model
def _evaluate_model(self,
input_fn,
steps,
feed_fn=None,
metrics=None,
name=''):
if self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset'):
return
checkpoint_path = self._model_dir
eval_dir = os.path.join(self._model_dir, 'eval' if not name else
'eval_' + name)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, targets = input_fn()
self._check_inputs(features, targets)
eval_dict = self._get_eval_ops(features, targets, metrics)
update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
eval_results, _ = evaluate(graph=g,
output_dir=eval_dir,
checkpoint_path=checkpoint_path,
eval_dict=eval_dict,
update_op=update_op,
global_step_tensor=global_step,
supervisor_master=self._config.master,
feed_fn=feed_fn,
max_steps=steps)
return eval_results
示例6: _infer_model
def _infer_model(self, x, batch_size=None, axis=None, proba=False):
    """Run a single `infer` call on `x` using the latest checkpoint.

    `axis` and `proba` are accepted but not used in this body. A `batch_size`
    of None is passed to `_get_predict_input_fn` as -1.
    """
    effective_batch_size = -1 if batch_size is None else batch_size
    input_fn, feed_fn = _get_predict_input_fn(x, effective_batch_size)
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(g)
        features, _ = input_fn()
        # Only one feed dict is drawn; there is no batching loop here.
        feed_dict = None if feed_fn is None else feed_fn()
        predictions = self._get_predict_ops(features)
        if not isinstance(predictions, dict):
            predictions = {'predictions': predictions}
        # TODO(ipolosukhin): Support batching
        return infer(checkpoint_path, predictions, feed_dict=feed_dict)
示例7: _infer_model
def _infer_model(self, input_fn, feed_fn=None, outputs=None):
    """Run predictions for `input_fn` against the latest checkpoint.

    Args:
      input_fn: Passed to `self._get_features_from_input_fn`.
      feed_fn: Optional callable producing one feed dict per call. Feed dicts
        are drawn from it until it raises StopIteration, and each key's
        per-feed outputs are concatenated along axis 0.
      outputs: Optional collection of prediction keys to keep.

    Returns:
      A dict of predictions, or the bare 'predictions' entry when the model's
      predict ops were not a dict.

    Raises:
      NotFittedError: If no checkpoint is found under `self._model_dir`.
      ValueError: If `outputs` matches none of the prediction keys.
    """
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        raise NotFittedError("Couldn't find trained model at %s."
                             % self._model_dir)
    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        predictions = self._get_predict_ops(features)
        # If predictions is a single output, wrap it into a dict and remember
        # to return it unwrapped.
        return_dict = True
        if not isinstance(predictions, dict):
            predictions, return_dict = {'predictions': predictions}, False
        # Filter what to run predictions on, if outputs were provided.
        if outputs:
            existing_keys = predictions.keys()
            predictions = {
                key: value for key, value in predictions.items() if key in outputs
            }
            if not predictions:
                raise ValueError('Expected to run at least one output from %s, '
                                 'provided %s.' % (existing_keys, outputs))
        if feed_fn is None:
            preds = graph_actions.infer(checkpoint_path, predictions)
        else:
            def _feed_fn():
                # PEP 479 (default since Python 3.7) turns a StopIteration that
                # escapes a generator body into RuntimeError. Translate
                # feed_fn's exhaustion signal into a normal end of generator.
                while True:
                    try:
                        feed_dict = feed_fn()
                    except StopIteration:
                        return
                    yield feed_dict

            # Use a distinct name so the `outputs` argument is not shadowed.
            feed_outputs = graph_actions.run_feeds(
                output_dict=predictions,
                feed_dicts=_feed_fn(),
                restore_checkpoint_path=checkpoint_path)
            preds = {
                key: np.concatenate(
                    [output[key] for output in feed_outputs], axis=0)
                for key in predictions
            }
        if return_dict:
            return preds
        return preds['predictions']
示例8: _evaluate_model
def _evaluate_model(self,
input_fn,
steps,
feed_fn=None,
metrics=None,
name=''):
# TODO(wicke): Remove this once Model and associated code are gone.
if (hasattr(self._config, 'execution_mode') and
self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')):
return None, None
# Check that model has been trained.
checkpoint_path = self._model_dir
latest_path = saver.latest_checkpoint(checkpoint_path)
if not latest_path:
raise NotFittedError("Couldn't find trained model at %s."
% checkpoint_path)
# Setup output directory.
eval_dir = os.path.join(self._model_dir, 'eval' if not name else
'eval_' + name)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
# The default return type of _get_eval_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_eval_ops returns an
# `eval_dict` dictionary of Tensors. The following else-statement code
# covers these cases, but will soon be deleted after the subclasses are
# updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
eval_ops = self._get_eval_ops(features, labels, metrics)
if isinstance(eval_ops, ModelFnOps): # Default signature
eval_dict = eval_ops.eval_metric_ops
else: # Legacy signature
eval_dict = eval_ops
update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
eval_results, current_global_step = graph_actions.evaluate(
graph=g,
output_dir=eval_dir,
checkpoint_path=checkpoint_path,
eval_dict=eval_dict,
update_op=update_op,
global_step_tensor=global_step,
supervisor_master=self._config.evaluation_master,
feed_fn=feed_fn,
max_steps=steps)
return eval_results, current_global_step
示例9: _train_model
def _train_model(self,
input_fn,
steps,
feed_fn=None,
device_fn=None,
monitor=None,
log_every_steps=100,
fail_on_nan_loss=True):
if self._config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(self._config.training_worker_max_startup_secs,
self._config.task *
self._config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
self._config.task)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or self._device_fn
with ops.Graph().as_default() as g, g.device(device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, targets = input_fn()
self._check_inputs(features, targets)
train_op, loss_op = self._get_train_ops(features, targets)
return train(
graph=g,
output_dir=self._model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
log_every_steps=log_every_steps,
supervisor_is_chief=(self._config.task == 0),
supervisor_master=self._config.master,
feed_fn=feed_fn,
max_steps=steps,
fail_on_nan_loss=fail_on_nan_loss)
示例10: _evaluate_model
def _evaluate_model(self,
input_fn,
steps,
feed_fn=None,
metrics=None,
name=''):
# TODO(wicke): Remove this once Model and associated code are gone.
if (hasattr(self._config, 'execution_mode') and
self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')):
return None, None
# Check that model has been trained.
checkpoint_path = self._model_dir
latest_path = saver.latest_checkpoint(checkpoint_path)
if not latest_path:
raise NotFittedError("Couldn't find trained model at %s."
% checkpoint_path)
# Setup output directory.
eval_dir = os.path.join(self._model_dir, 'eval' if not name else
'eval_' + name)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, targets = input_fn()
self._check_inputs(features, targets)
eval_dict = self._get_eval_ops(features, targets, metrics)
update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
eval_results, current_global_step = graph_actions.evaluate(
graph=g,
output_dir=eval_dir,
checkpoint_path=checkpoint_path,
eval_dict=eval_dict,
update_op=update_op,
global_step_tensor=global_step,
supervisor_master=self._config.evaluation_master,
feed_fn=feed_fn,
max_steps=steps)
return eval_results, current_global_step
示例11: _evaluate_model
def _evaluate_model(self, input_fn, steps, feed_fn=None, metrics=None):
if self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset'):
return
checkpoint_path = saver.latest_checkpoint(self._model_dir)
eval_dir = os.path.join(self._model_dir, 'eval')
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, targets = input_fn()
self._check_inputs(features, targets)
eval_dict = self._get_eval_ops(features, targets,
metrics if metrics is not None else
self._get_default_metric_functions())
eval_results, _ = evaluate(
graph=g,
output_dir=eval_dir,
checkpoint_path=checkpoint_path,
eval_dict=eval_dict,
global_step_tensor=global_step,
supervisor_master=self._config.master,
feed_fn=feed_fn,
max_steps=steps)
return eval_results
示例12: _train_model
def _train_model(self,
input_fn,
steps,
feed_fn=None,
init_op=None,
init_feed_fn=None,
init_fn=None,
device_fn=None,
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True,
max_steps=None):
# TODO(wicke): Remove this once Model and associated code are gone.
if hasattr(self._config, 'execution_mode'):
if self._config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(
self._config.training_worker_max_startup_secs,
self._config.task *
self._config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
self._config.task)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or self._device_fn
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, targets = input_fn()
self._check_inputs(features, targets)
train_op, loss_op = self._get_train_ops(features, targets)
# Add default monitors.
if monitors is None:
monitors = []
hooks = [m for m in monitors
if isinstance(m, session_run_hook.SessionRunHook)]
deprecated_monitors = [
m for m in monitors
if not isinstance(m, session_run_hook.SessionRunHook)
]
supervisor_is_chief = self._config.is_chief
if not supervisor_is_chief:
# Prune list of monitor to the ones runnable on all workers.
deprecated_monitors = [m for m in deprecated_monitors
if m.run_on_all_workers]
# Setup monitors.
for monitor in deprecated_monitors:
monitor.set_estimator(self)
if deprecated_monitors:
hooks.append(monitor_lib.RunHookAdapterForMonitors(deprecated_monitors))
return graph_actions._monitored_train( # pylint: disable=protected-access
graph=g,
output_dir=self._model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
init_op=init_op,
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
init_fn=init_fn,
log_every_steps=log_every_steps,
supervisor_is_chief=supervisor_is_chief,
supervisor_master=self._config.master,
supervisor_save_model_secs=self._config.save_checkpoints_secs,
supervisor_save_summaries_steps=self._config.save_summary_steps,
keep_checkpoint_max=self._config.keep_checkpoint_max,
feed_fn=feed_fn,
steps=steps,
fail_on_nan_loss=fail_on_nan_loss,
hooks=hooks,
max_steps=max_steps)
示例13: _train_model
def _train_model(self,
input_fn,
steps,
feed_fn=None,
init_op=None,
init_feed_fn=None,
init_fn=None,
device_fn=None,
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True):
if self._config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(self._config.training_worker_max_startup_secs,
self._config.task *
self._config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
self._config.task)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or self._device_fn
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, targets = input_fn()
self._check_inputs(features, targets)
train_op, loss_op = self._get_train_ops(features, targets)
# Add default monitors.
if monitors is None:
monitors = []
monitors += monitors_lib.get_default_monitors(
loss_op=loss_op,
summary_op=logging_ops.get_summary_op(),
save_summary_steps=100,
summary_writer=graph_actions.get_summary_writer(self._model_dir))
is_chief = self._config.task == 0
if not is_chief:
# Run monitors only on chief.
monitors = []
# Setup monitors.
for monitor in monitors:
monitor.set_estimator(self)
return train(
graph=g,
output_dir=self._model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
init_op=init_op,
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
init_fn=init_fn,
log_every_steps=log_every_steps,
supervisor_is_chief=is_chief,
supervisor_master=self._config.master,
feed_fn=feed_fn,
max_steps=steps,
fail_on_nan_loss=fail_on_nan_loss,
monitors=monitors)
示例14: _train_model
def _train_model(self,
input_fn,
steps,
feed_fn=None,
init_op=None,
init_feed_fn=None,
init_fn=None,
device_fn=None,
monitors=None,
log_every_steps=100,
fail_on_nan_loss=True,
max_steps=None):
# TODO(wicke): Remove this once Model and associated code are gone.
if hasattr(self._config, 'execution_mode'):
if self._config.execution_mode not in ('all', 'train'):
return
# Stagger startup of worker sessions based on task id.
sleep_secs = min(
self._config.training_worker_max_startup_secs,
self._config.task *
self._config.training_worker_session_startup_stagger_secs)
if sleep_secs:
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
self._config.task)
time.sleep(sleep_secs)
# Device allocation
device_fn = device_fn or self._device_fn
self._graph = ops.Graph()
with self._graph.as_default() as g, g.device(device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step = contrib_framework.create_global_step(g)
features, labels = input_fn()
self._check_inputs(features, labels)
# The default return type of _get_train_ops is ModelFnOps. But there are
# some subclasses of tf.contrib.learn.Estimator which override this
# method and use the legacy signature, namely _get_train_ops returns a
# (train_op, loss) tuple. The following else-statement code covers these
# cases, but will soon be deleted after the subclasses are updated.
# TODO(b/32664904): Update subclasses and delete the else-statement.
train_ops = self._get_train_ops(features, labels)
if isinstance(train_ops, ModelFnOps): # Default signature
train_op = train_ops.train_op
loss_op = train_ops.loss
else: # Legacy signature
if len(train_ops) != 2:
raise ValueError('Expected a tuple of train_op and loss, got {}'.
format(train_ops))
train_op = train_ops[0]
loss_op = train_ops[1]
hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
return graph_actions._monitored_train( # pylint: disable=protected-access
graph=g,
output_dir=self._model_dir,
train_op=train_op,
loss_op=loss_op,
global_step_tensor=global_step,
init_op=init_op,
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
init_fn=init_fn,
log_every_steps=log_every_steps,
supervisor_is_chief=self.config.is_chief,
supervisor_master=self._config.master,
supervisor_save_model_secs=self._config.save_checkpoints_secs,
supervisor_save_model_steps=self._config.save_checkpoints_steps,
supervisor_save_summaries_steps=self._config.save_summary_steps,
keep_checkpoint_max=self._config.keep_checkpoint_max,
feed_fn=feed_fn,
steps=steps,
fail_on_nan_loss=fail_on_nan_loss,
hooks=hooks,
max_steps=max_steps)