本文整理汇总了 Python 中 tensorflow.train 模块的典型用法代码示例。如果您正苦于以下问题:Python tensorflow.train 的具体用法?Python tensorflow.train 怎么用?Python tensorflow.train 使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在的包 tensorflow 的用法示例。
在下文中一共展示了 tensorflow.train 模块的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: _create_learning_rate
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def _create_learning_rate(hyperparams, step_var):
    """Builds the exponentially decayed learning-rate tensor.

    When `hyperparams.learning_method` is 'composite', the rate fed into the
    decay is selected with `tf.cond`: the first method's learning rate while
    `step_var` is below `switch_after_steps`, the second method's learning
    rate otherwise. For every other learning method the rate is simply
    `hyperparams.learning_rate`.

    Args:
      hyperparams: a GridPoint proto containing optimizer spec, particularly
        learning_method to determine optimizer class to use.
      step_var: tf.Variable, global training step.

    Returns:
      a scalar `Tensor`, the learning rate based on current step and
      hyperparams.
    """
    if hyperparams.learning_method == 'composite':
        composite = hyperparams.composite_optimizer_spec

        def _rate_before_switch():
            return tf.constant(composite.method1.learning_rate)

        def _rate_after_switch():
            return tf.constant(composite.method2.learning_rate)

        before_switch = tf.less(step_var, composite.switch_after_steps)
        initial_rate = tf.cond(
            before_switch, _rate_before_switch, _rate_after_switch)
    else:
        initial_rate = hyperparams.learning_rate

    return tf.train.exponential_decay(
        initial_rate,
        step_var,
        hyperparams.decay_steps,
        hyperparams.decay_base,
        staircase=hyperparams.decay_staircase)
示例2: close
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def close(self):
    """Closes the files associated to the TFRecordWriter objects.

    Closing is best-effort: `self.test`, `self.valid` and every element of
    `self.train` are closed independently, so a writer that is missing or
    fails to close never prevents the remaining ones from being closed.
    Failures are reported at DEBUG level instead of being dropped silently.

    Returns
    -------
    None
    """
    # Function-scope import keeps this block self-contained.
    import logging

    logger = logging.getLogger(__name__)

    try:
        self.test.close()
    except Exception:  # best-effort: keep closing the other writers
        logger.debug("Could not close the test writer.", exc_info=True)

    try:
        self.valid.close()
    except Exception:  # best-effort: keep closing the other writers
        logger.debug("Could not close the valid writer.", exc_info=True)

    for writer in self.train:
        try:
            writer.close()
        except Exception:  # best-effort: keep closing the other writers
            logger.debug("Could not close a train writer.", exc_info=True)
示例3: start_server_if_distributed
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def start_server_if_distributed(self):
    """Starts a server when `self.cluster` is set.

    Returns:
      A `(target, device_fn)` tuple. Without a cluster both entries are the
      empty string. With a cluster, `target` is the started server's target
      and `device_fn` is the result of `tf.train.replica_device_setter` for
      this task.
    """
    # Non-distributed execution: nothing to start.
    if not self.cluster:
        return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    started = start_server(self.cluster, self.task)
    session_target = started.target
    worker = "/job:%s/task:%d" % (self.task.type, self.task.index)
    placement_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker,
        cluster=self.cluster)
    return (session_target, placement_fn)
示例4: get_meta_filename
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def get_meta_filename(self, start_new_model, train_dir):
    """Returns the ".meta" file of the latest checkpoint, or None.

    Args:
      start_new_model: when truthy, no checkpoint lookup is performed.
      train_dir: directory passed to `tf.train.latest_checkpoint`.

    Returns:
      The latest checkpoint path with ".meta" appended if that file exists;
      None (after logging the reason) when `start_new_model` is set, when no
      checkpoint is found, or when the ".meta" file does not exist.
    """
    def _new_model(reason_format):
        # Log why no meta graph is reused; the caller then gets None.
        logging.info(reason_format, task_as_string(self.task))
        return None

    if start_new_model:
        return _new_model(
            "%s: Flag 'start_new_model' is set. Building a new model.")

    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    if not checkpoint_path:
        return _new_model(
            "%s: No checkpoint file found. Building a new model.")

    meta_path = checkpoint_path + ".meta"
    if gfile.Exists(meta_path):
        return meta_path
    return _new_model("%s: No meta graph file found. Building a new model.")
示例5: build_model
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def build_model(self, model, reader):
    """Find the model and build the graph.

    The loss function and optimizer class are looked up by name from
    `FLAGS.label_loss` (in `losses`) and `FLAGS.optimizer` (in `tf.train`);
    the remaining `build_graph` settings are read from `FLAGS`.

    Args:
      model: forwarded to `build_graph` as `model`.
      reader: forwarded to `build_graph` as `reader`.

    Returns:
      A `tf.train.Saver` created with `max_to_keep=0` and
      `keep_checkpoint_every_n_hours=5`.
    """
    loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
    optimizer_cls = find_class_by_name(FLAGS.optimizer, [tf.train])

    graph_options = dict(
        reader=reader,
        model=model,
        optimizer_class=optimizer_cls,
        clip_gradient_norm=FLAGS.clip_gradient_norm,
        train_data_pattern=FLAGS.train_data_pattern,
        label_loss_fn=loss_fn,
        base_learning_rate=FLAGS.base_learning_rate,
        learning_rate_decay=FLAGS.learning_rate_decay,
        learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
        regularization_penalty=FLAGS.regularization_penalty,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs,
    )
    build_graph(**graph_options)

    return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=5)
示例6: start_server
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def start_server(cluster, task):
    """Creates a Server.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.

    Returns:
      The `tf.train.Server` built for `cluster` with the "grpc" protocol,
      using `task.type` as job name and `task.index` as task index.

    Raises:
      ValueError: if `task.type` is falsy or `task.index` is None.
    """
    # Validate the task first; the type is checked before the index.
    problem = None
    if not task.type:
        problem = "%s: The task type must be specified."
    elif task.index is None:
        problem = "%s: The task index must be specified."
    if problem is not None:
        raise ValueError(problem % task_as_string(task))

    # Create and start a server.
    cluster_spec = tf.train.ClusterSpec(cluster)
    return tf.train.Server(
        cluster_spec,
        protocol="grpc",
        job_name=task.type,
        task_index=task.index)
示例7: get_input_data_tensors
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def get_input_data_tensors(reader,
                           data_pattern,
                           batch_size=256,
                           num_epochs=None):
    """Creates the batched training-input tensors.

    Files matching `data_pattern` are globbed, sorted, and fed unshuffled
    through a `tf.train.string_input_producer` into
    `reader.prepare_reader`, whose output is batched with `tf.train.batch`.

    Args:
      reader: object providing `prepare_reader(filename_queue)`.
      data_pattern: glob pattern passed to `gfile.Glob`.
      batch_size: number of examples per batch; also determines the batching
        queue capacity (`batch_size * 4`).
      num_epochs: forwarded to `tf.train.string_input_producer`.

    Returns:
      The result of `tf.train.batch` over the reader output.

    Raises:
      IOError: if no file matches `data_pattern`.
    """
    logging.info("Using batch size of %s for training.", batch_size)
    # NOTE(review): nesting under the name scope was reconstructed from an
    # unindented listing — confirm against the original file.
    with tf.name_scope("train_input"):
        files = gfile.Glob(data_pattern)
        if not files:
            raise IOError("Unable to find training files. data_pattern='" +
                          data_pattern + "'.")
        logging.info("Number of training files: %s.", str(len(files)))
        files.sort()
        filename_queue = tf.train.string_input_producer(
            files, num_epochs=num_epochs, shuffle=False)
        training_data = reader.prepare_reader(filename_queue)
        return tf.train.batch(
            training_data,
            batch_size=batch_size,
            # Size the queue from the batch size actually requested, not from
            # the global flag, so the two cannot get out of step when a
            # caller passes a different `batch_size`.
            capacity=batch_size * 4,
            allow_smaller_final_batch=True,
            enqueue_many=True)
示例8: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored as `self.train_dir`; presumably the directory that
        holds training checkpoints (TODO confirm against callers).
      log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
      ValueError: if `task` is a "master" task with an index greater than 0.
    """
    # Test the task type directly: `is_master` below already implies
    # index == 0, so checking it could never detect a second master replica.
    # ValueError exists on Python 2 and 3 (and derives from StandardError on
    # Python 2), and the message is %-formatted rather than passed as a tuple.
    if (task.type == "master" and task.index is not None
            and task.index > 0):
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)
示例9: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored as `self.train_dir`; presumably the directory that
        holds training checkpoints (TODO confirm against callers).
      log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
      ValueError: if `task` is a "master" task with an index greater than 0.
    """
    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # A "master" task that is not replica 0 is a configuration error. The
    # guard must look at the task type, not `self.is_master` (which is only
    # true for index 0). ValueError is available on Python 2 and 3, unlike
    # StandardError, and the message is formatted before raising.
    extra_master = (task.type == "master" and task.index is not None
                    and task.index > 0)
    if extra_master:
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))
示例10: start_server_if_distributed
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def start_server_if_distributed(self):
    """Starts a server if the execution is distributed.

    Returns:
      A `(target, device_fn)` tuple: two empty strings when `self.cluster`
      is falsy; otherwise the started server's target and the device
      function returned by `tf.train.replica_device_setter`.
    """
    # Defaults for the non-distributed case.
    target, device_fn = "", ""

    if self.cluster:
        logging.info("%s: Starting trainer within cluster %s.",
                     task_as_string(self.task), self.cluster.as_dict())
        target = start_server(self.cluster, self.task).target
        worker_device = "/job:%s/task:%d" % (self.task.type,
                                             self.task.index)
        device_fn = tf.train.replica_device_setter(
            ps_device="/job:ps",
            worker_device=worker_device,
            cluster=self.cluster)

    return (target, device_fn)
示例11: get_meta_filename
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def get_meta_filename(self, start_new_model, train_dir):
    """Locates the meta-graph file that belongs to the latest checkpoint.

    Args:
      start_new_model: when truthy, the checkpoint lookup is skipped.
      train_dir: directory handed to `tf.train.latest_checkpoint`.

    Returns:
      `<latest checkpoint>.meta` when that file exists. Otherwise None, after
      logging which of the three conditions applied: flag set, no
      checkpoint, or no meta graph file.
    """
    if start_new_model:
        reason = "%s: Flag 'start_new_model' is set. Building a new model."
    else:
        newest = tf.train.latest_checkpoint(train_dir)
        if not newest:
            reason = "%s: No checkpoint file found. Building a new model."
        else:
            candidate = newest + ".meta"
            if gfile.Exists(candidate):
                return candidate
            reason = "%s: No meta graph file found. Building a new model."

    # Every path that reaches this point builds a new model.
    logging.info(reason, task_as_string(self.task))
    return None
示例12: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True,
             gpu_memory_fraction=0.2):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored as `self.train_dir`; presumably the directory that
        holds training checkpoints (TODO confirm against callers).
      log_device_placement: Forwarded to `tf.ConfigProto`.
      gpu_memory_fraction: Value used for
        `tf.GPUOptions(per_process_gpu_memory_fraction=...)`. The default of
        0.2 is the value that used to be hard-coded.

    Raises:
      ValueError: if `task` is a "master" task with an index greater than 0.
    """
    # Test the task type directly: `is_master` below already implies
    # index == 0, so checking it could never detect a second master replica.
    # ValueError exists on Python 2 and 3 (and derives from StandardError on
    # Python 2), and the message is %-formatted rather than passed as a tuple.
    if (task.type == "master" and task.index is not None
            and task.index > 0):
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=gpu_memory_fraction)
    self.config = tf.ConfigProto(
        log_device_placement=log_device_placement,
        gpu_options=gpu_options)
示例13: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored as `self.train_dir`; presumably the directory that
        holds training checkpoints (TODO confirm against callers).
      log_device_placement: Forwarded to `tf.ConfigProto`.

    Raises:
      ValueError: if `task` is a "master" task with an index greater than 0.
    """
    # Test the task type directly: `is_master` below already implies
    # index == 0, so checking it could never detect a second master replica.
    # ValueError exists on Python 2 and 3 (and derives from StandardError on
    # Python 2), and the message is %-formatted rather than passed as a tuple.
    if (task.type == "master" and task.index is not None
            and task.index > 0):
        raise ValueError("%s: Only one replica of master expected" %
                         task_as_string(task))

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu)
    # The GPU options must be handed to the session config; building them
    # without passing them on leaves the `FLAGS.gpu` setting without effect.
    self.config = tf.ConfigProto(
        log_device_placement=log_device_placement,
        gpu_options=gpu_options)
示例14: add_saver
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def add_saver(self):
    """Adds a Saver for the non-quantized global variables.

    Variables whose name contains 'quantized' are excluded. The names of the
    remaining variables are logged, and a `tf.train.Saver` over them (V1
    write version) is stored in `self.saver`.
    """
    def _non_quantized():
        # Global variables whose name does not mention 'quantized'.
        return [var for var in tf.global_variables()
                if 'quantized' not in var.name]

    kept_names = [var.name for var in _non_quantized()]
    logging.info('Saving non-quantized variables:\n\t%s',
                 '\n\t'.join(kept_names))
    self.saver = tf.train.Saver(
        var_list=_non_quantized(),
        write_version=saver_pb2.SaverDef.V1)
示例15: visualize
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def visualize(
        logdir, outdir, num_agents, num_episodes, checkpoint=None,
        env_processes=True):
    """Recover checkpoint and render videos from it.

    The batch environment is always closed before returning, including when
    graph construction or the simulation loop raises.

    Args:
      logdir: Logging directory of the trained algorithm.
      outdir: Directory to store rendered videos in.
      num_agents: Number of environments to simulate in parallel.
      num_episodes: Total number of episodes to simulate.
      checkpoint: Checkpoint name to load; defaults to most recent.
      env_processes: Whether to step environments in separate processes.
    """
    # NOTE(review): block nesting was reconstructed from an unindented
    # listing — confirm against the original file.
    config = utility.load_config(logdir)
    with config.unlocked:
        config.network = functools.partial(
            utility.define_network, config.network, config)
        # Optimizers are stored by name in the config; resolve them to the
        # corresponding `tf.train` attributes.
        config.policy_optimizer = getattr(tf.train, config.policy_optimizer)
        config.value_optimizer = getattr(tf.train, config.value_optimizer)
    with tf.device('/cpu:0'):
        batch_env = utility.define_batch_env(
            lambda: _create_environment(config, outdir),
            num_agents, env_processes)
    # From here on the environments exist, so release them on every exit
    # path rather than only after a successful run.
    try:
        with tf.device('/cpu:0'):
            graph = utility.define_simulation_graph(
                batch_env, config.algorithm, config)
            total_steps = num_episodes * config.max_length
            loop = _define_loop(graph, total_steps)
        saver = utility.define_saver(
            exclude=(r'.*_temporary/.*', r'global_step'))
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True
        with tf.Session(config=sess_config) as sess:
            utility.initialize_variables(
                sess, saver, config.logdir, checkpoint, resume=True)
            # The loop is run only for its side effects; scores are ignored.
            for unused_score in loop.run(sess, saver, total_steps):
                pass
    finally:
        batch_env.close()