当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.train方法代码示例

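本文整理汇总了Python中tensorflow.train模块（TensorFlow 1.x 中与训练相关的 API 命名空间，包含 Saver、exponential_decay、batch 等）的典型用法代码示例。如果您想了解：tensorflow.train 的具体用法是什么？tensorflow.train 怎么用？有哪些使用 tensorflow.train 的例子？这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在的 tensorflow 库的其他用法示例。注意：下列示例基于 TensorFlow 1.x 的 API；在 TensorFlow 2.x 中，对应的接口大多位于 tf.compat.v1.train 之下。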


在下文中一共展示了tensorflow.train方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _create_learning_rate

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def _create_learning_rate(hyperparams, step_var):
  """Build the exponentially decayed learning rate, switching for composites.

  For any learning method other than 'composite' the base rate is simply
  ``hyperparams.learning_rate``. For 'composite' the base rate is a tensor
  that yields ``method1``'s rate while ``step_var`` is below
  ``switch_after_steps`` and ``method2``'s rate afterwards. The base rate is
  then decayed with ``tf.train.exponential_decay`` using the decay settings
  from ``hyperparams``.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    step_var: tf.Variable, global training step.

  Returns:
    a scalar `Tensor`, the learning rate based on current step and hyperparams.
  """
  if hyperparams.learning_method == 'composite':
    composite = hyperparams.composite_optimizer_spec
    before_switch = tf.less(step_var, composite.switch_after_steps)

    def _first_rate():
      return tf.constant(composite.method1.learning_rate)

    def _second_rate():
      return tf.constant(composite.method2.learning_rate)

    base_rate = tf.cond(before_switch, _first_rate, _second_rate)
  else:
    base_rate = hyperparams.learning_rate

  return tf.train.exponential_decay(
      base_rate,
      step_var,
      hyperparams.decay_steps,
      hyperparams.decay_base,
      staircase=hyperparams.decay_staircase)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:26,代码来源:graph_builder.py

示例2: close

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def close(self):
        """Close the TFRecordWriter objects held by this instance, best effort.

        The test writer is closed first, then the validation writer, then
        every writer in ``self.train``. Any exception raised while fetching
        or closing an individual writer is swallowed so the remaining writers
        still get a chance to close.

        Returns
        -------
        None
        """

        def _close_quietly(fetch_writer):
            # The writer is fetched inside the ``try`` so that a missing
            # attribute is ignored just like a failing ``close()``.
            try:
                fetch_writer().close()
            except Exception:
                pass

        _close_quietly(lambda: self.test)
        _close_quietly(lambda: self.valid)
        for writer in self.train:
            _close_quietly(lambda: writer)
开发者ID:imsb-uke,项目名称:scGAN,代码行数:26,代码来源:write_tfrecords.py

示例3: start_server_if_distributed

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def start_server_if_distributed(self):
    """Start a server when a cluster is configured and report how to use it.

    When ``self.cluster`` is falsy, no server is started and both returned
    values are empty strings. Otherwise a server is started for
    ``self.task`` and a ``tf.train.replica_device_setter`` is built that
    places variables on "/job:ps" and ops on this task's worker device.

    Returns:
      A ``(target, device_fn)`` tuple.
    """
    if not self.cluster:
      return ("", "")

    logging.info("%s: Starting trainer within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    session_target = server.target
    worker_device = "/job:%s/task:%d" % (self.task.type, self.task.index)
    placement_fn = tf.train.replica_device_setter(
        ps_device="/job:ps",
        worker_device=worker_device,
        cluster=self.cluster)
    return (session_target, placement_fn)
开发者ID:antoine77340,项目名称:Youtube-8M-WILLOW,代码行数:18,代码来源:train.py

示例4: get_meta_filename

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def get_meta_filename(self, start_new_model, train_dir):
    """Return the meta-graph path of the latest checkpoint, or None.

    None is returned, and the reason logged, when a new model was
    explicitly requested, when ``train_dir`` has no checkpoint, or when the
    latest checkpoint has no accompanying ``.meta`` file.

    Args:
      start_new_model: if truthy, skip the checkpoint lookup entirely.
      train_dir: directory searched with ``tf.train.latest_checkpoint``.

    Returns:
      ``<latest checkpoint>.meta`` if it exists, otherwise None.
    """
    reason = None
    meta_filename = None
    if start_new_model:
      reason = "%s: Flag 'start_new_model' is set. Building a new model."
    else:
      latest_checkpoint = tf.train.latest_checkpoint(train_dir)
      if not latest_checkpoint:
        reason = "%s: No checkpoint file found. Building a new model."
      else:
        meta_filename = latest_checkpoint + ".meta"
        if not gfile.Exists(meta_filename):
          reason = "%s: No meta graph file found. Building a new model."

    if reason is not None:
      logging.info(reason, task_as_string(self.task))
      return None
    return meta_filename
开发者ID:antoine77340,项目名称:Youtube-8M-WILLOW,代码行数:21,代码来源:train.py

示例5: build_model

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def build_model(self, model, reader):
    """Find the model and build the graph.

    The label loss class is looked up by ``FLAGS.label_loss`` in ``losses``
    and instantiated. The optimizer class is looked up by ``FLAGS.optimizer``
    in ``tf.train``. Both are passed to ``build_graph`` together with the
    remaining training flags.

    Args:
      model: the model object handed to ``build_graph``.
      reader: the input reader handed to ``build_graph``.

    Returns:
      A ``tf.train.Saver`` created with ``max_to_keep=0`` and
      ``keep_checkpoint_every_n_hours=5``.
    """
    loss_class = find_class_by_name(FLAGS.label_loss, [losses])
    label_loss_fn = loss_class()
    optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])

    graph_kwargs = dict(
        reader=reader,
        model=model,
        optimizer_class=optimizer_class,
        clip_gradient_norm=FLAGS.clip_gradient_norm,
        train_data_pattern=FLAGS.train_data_pattern,
        label_loss_fn=label_loss_fn,
        base_learning_rate=FLAGS.base_learning_rate,
        learning_rate_decay=FLAGS.learning_rate_decay,
        learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
        regularization_penalty=FLAGS.regularization_penalty,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs,
    )
    build_graph(**graph_kwargs)

    return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=5)
开发者ID:antoine77340,项目名称:Youtube-8M-WILLOW,代码行数:23,代码来源:train.py

示例6: start_server

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def start_server(cluster, task):
  """Create and start a gRPC ``tf.train.Server`` for ``task``.

  Args:
    cluster: A tf.train.ClusterSpec if the execution is distributed.
      None otherwise.
    task: A TaskSpec describing the job type and the task index.

  Returns:
    The ``tf.train.Server`` built from ``tf.train.ClusterSpec(cluster)``,
    with ``job_name=task.type`` and ``task_index=task.index``.

  Raises:
    ValueError: if ``task.type`` is falsy or ``task.index`` is None.
  """
  if not task.type:
    raise ValueError(
        "%s: The task type must be specified." % task_as_string(task))
  if task.index is None:
    raise ValueError(
        "%s: The task index must be specified." % task_as_string(task))

  server_cls = tf.train.Server
  cluster_spec = tf.train.ClusterSpec(cluster)
  return server_cls(
      cluster_spec,
      protocol="grpc",
      job_name=task.type,
      task_index=task.index)
开发者ID:antoine77340,项目名称:Youtube-8M-WILLOW,代码行数:24,代码来源:train.py

示例7: get_input_data_tensors

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def get_input_data_tensors(reader,
                           data_pattern,
                           batch_size=256,
                           num_epochs=None):
  """Build a batched input pipeline over the files matching ``data_pattern``.

  Matching files are sorted and fed, unshuffled, through a filename queue.
  The queue is parsed by ``reader.prepare_reader``, and the parsed tensors
  are grouped into batches with ``tf.train.batch``.

  Args:
    reader: object whose ``prepare_reader(filename_queue)`` returns the
      tensors to batch (assumed to be already batched along the first
      dimension, since ``enqueue_many=True`` — verify against the reader).
    data_pattern: glob pattern for the training files.
    batch_size: number of examples per batch; also sizes the queue capacity.
    num_epochs: forwarded to ``tf.train.string_input_producer``.

  Returns:
    The tensors produced by ``tf.train.batch``. The final batch may be
    smaller than ``batch_size`` (``allow_smaller_final_batch=True``).

  Raises:
    IOError: if no files match ``data_pattern``.
  """
  logging.info("Using batch size of " + str(batch_size) + " for training.")
  with tf.name_scope("train_input"):
    files = gfile.Glob(data_pattern)
    if not files:
      raise IOError("Unable to find training files. data_pattern='" +
                    data_pattern + "'.")
    logging.info("Number of training files: %s.", str(len(files)))
    files.sort()
    filename_queue = tf.train.string_input_producer(
        files, num_epochs=num_epochs, shuffle=False)
    training_data = reader.prepare_reader(filename_queue)

    # Derive the queue capacity from the batch size actually requested
    # (not a global flag) so the queue can always hold at least one full
    # batch; a capacity below batch_size can stall the batch dequeue.
    return tf.train.batch(
        training_data,
        batch_size=batch_size,
        capacity=batch_size * 4,
        allow_smaller_final_batch=True,
        enqueue_many=True)
开发者ID:wangheda,项目名称:youtube-8m,代码行数:24,代码来源:train.py

示例8: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Directory stored as ``self.train_dir``.
      log_device_placement: Forwarded to ``tf.ConfigProto``.

    Raises:
      ValueError: if ``task`` is a "master" task with an index above 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # Only one master replica is supported. Test the task type directly:
    # ``self.is_master`` already implies ``index == 0``, so guarding on it
    # would make this check unreachable. ValueError replaces the
    # Python-2-only StandardError, and the message is %-formatted rather
    # than passing the argument as a second exception arg.
    if (task.type == "master" and task.index is not None
        and task.index > 0):
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
开发者ID:wangheda,项目名称:youtube-8m,代码行数:20,代码来源:train.py

示例9: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True):
        """Creates a Trainer.

        Args:
          cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
          task: A TaskSpec describing the job type and the task index.
          train_dir: Directory stored as ``self.train_dir``.
          log_device_placement: Forwarded to ``tf.ConfigProto``.

        Raises:
          ValueError: if ``task`` is a "master" task with an index above 0.
        """

        self.cluster = cluster
        self.task = task
        self.is_master = (task.type == "master" and task.index == 0)
        self.train_dir = train_dir
        self.config = tf.ConfigProto(log_device_placement=log_device_placement)

        # Only one master replica is supported. Test the task type directly:
        # ``self.is_master`` already implies ``index == 0``, so guarding on
        # it would make this check unreachable. ValueError replaces the
        # Python-2-only StandardError, and the message is %-formatted rather
        # than passing the argument as a second exception arg.
        if (task.type == "master" and task.index is not None
                and task.index > 0):
            raise ValueError("%s: Only one replica of master expected" %
                             task_as_string(self.task))
开发者ID:wangheda,项目名称:youtube-8m,代码行数:20,代码来源:train-with-rebuild.py

示例10: start_server_if_distributed

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def start_server_if_distributed(self):
        """Return ``(target, device_fn)`` describing where to run this task.

        Both values default to empty strings. If ``self.cluster`` is truthy,
        a server is started for ``self.task``. ``target`` then becomes the
        server's target, and ``device_fn`` becomes a
        ``tf.train.replica_device_setter`` mapping variables to "/job:ps" and
        ops to this task's worker device.
        """
        target, device_fn = "", ""
        if self.cluster:
            logging.info("%s: Starting trainer within cluster %s.",
                         task_as_string(self.task), self.cluster.as_dict())
            server = start_server(self.cluster, self.task)
            target = server.target
            worker_device = "/job:%s/task:%d" % (self.task.type,
                                                 self.task.index)
            device_fn = tf.train.replica_device_setter(
                ps_device="/job:ps",
                worker_device=worker_device,
                cluster=self.cluster)
        return (target, device_fn)
开发者ID:wangheda,项目名称:youtube-8m,代码行数:18,代码来源:train-with-rebuild.py

示例11: get_meta_filename

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def get_meta_filename(self, start_new_model, train_dir):
        """Locate the ``.meta`` file of the newest checkpoint in ``train_dir``.

        Args:
          start_new_model: if truthy, no checkpoint lookup is attempted.
          train_dir: directory passed to ``tf.train.latest_checkpoint``.

        Returns:
          The meta-graph path when it exists. Otherwise None, after logging
          why a new model will be built.
        """

        def _build_new(message):
            # Log the reason a fresh model is needed and signal it with None.
            logging.info(message, task_as_string(self.task))
            return None

        if start_new_model:
            return _build_new(
                "%s: Flag 'start_new_model' is set. Building a new model.")

        checkpoint_prefix = tf.train.latest_checkpoint(train_dir)
        if not checkpoint_prefix:
            return _build_new(
                "%s: No checkpoint file found. Building a new model.")

        meta_path = checkpoint_prefix + ".meta"
        if gfile.Exists(meta_path):
            return meta_path
        return _build_new(
            "%s: No meta graph file found. Building a new model.")
开发者ID:wangheda,项目名称:youtube-8m,代码行数:21,代码来源:train-with-rebuild.py

示例12: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    The session config limits each process to 20% of GPU memory
    (``per_process_gpu_memory_fraction=0.2``).

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Directory stored as ``self.train_dir``.
      log_device_placement: Forwarded to ``tf.ConfigProto``.

    Raises:
      ValueError: if ``task`` is a "master" task with an index above 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    self.config = tf.ConfigProto(log_device_placement=log_device_placement,
                                 gpu_options=gpu_options)

    # Only one master replica is supported. Test the task type directly:
    # ``self.is_master`` already implies ``index == 0``, so guarding on it
    # would make this check unreachable. ValueError replaces the
    # Python-2-only StandardError, and the message is %-formatted rather
    # than passing the argument as a second exception arg.
    if (task.type == "master" and task.index is not None
        and task.index > 0):
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
开发者ID:wangheda,项目名称:youtube-8m,代码行数:21,代码来源:train_embedding.py

示例13: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    The session config's per-process GPU memory fraction is taken from
    ``FLAGS.gpu``.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Directory stored as ``self.train_dir``.
      log_device_placement: Forwarded to ``tf.ConfigProto``.

    Raises:
      ValueError: if ``task`` is a "master" task with an index above 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    # Attach the GPU options to the session config so FLAGS.gpu takes effect.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu)
    self.config = tf.ConfigProto(log_device_placement=log_device_placement,
                                 gpu_options=gpu_options)

    # Only one master replica is supported. Test the task type directly:
    # ``self.is_master`` already implies ``index == 0``, so guarding on it
    # would make this check unreachable. ValueError replaces the
    # Python-2-only StandardError, and the message is %-formatted rather
    # than passing the argument as a second exception arg.
    if (task.type == "master" and task.index is not None
        and task.index > 0):
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
开发者ID:wangheda,项目名称:youtube-8m,代码行数:21,代码来源:train.py

示例14: add_saver

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def add_saver(self):
    """Adds ``self.saver`` covering all non-quantized global variables.

    Global variables whose name contains 'quantized' are excluded. The names
    of the saved variables are logged. The saver writes checkpoints with
    ``write_version=saver_pb2.SaverDef.V1``.
    """
    # Filter once so the logged list and the saved list always match.
    saved_vars = [
        x for x in tf.global_variables() if 'quantized' not in x.name
    ]
    logging.info('Saving non-quantized variables:\n\t%s',
                 '\n\t'.join([x.name for x in saved_vars]))
    self.saver = tf.train.Saver(
        var_list=saved_vars,
        write_version=saver_pb2.SaverDef.V1)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:11,代码来源:graph_builder.py

示例15: visualize

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import train [as 别名]
def visualize(
    logdir, outdir, num_agents, num_episodes, checkpoint=None,
    env_processes=True):
  """Recover checkpoint and render videos from it.

  The batch environment is closed in a ``finally`` clause, so it is
  released even if graph construction or the session run raises.

  Args:
    logdir: Logging directory of the trained algorithm.
    outdir: Directory to store rendered videos in.
    num_agents: Number of environments to simulate in parallel.
    num_episodes: Total number of episodes to simulate.
    checkpoint: Checkpoint name to load; defaults to most recent.
    env_processes: Whether to step environments in separate processes.
  """
  config = utility.load_config(logdir)
  with config.unlocked:
    config.network = functools.partial(
        utility.define_network, config.network, config)
    # The config stores optimizer names; resolve them to tf.train attributes
    # (presumably optimizer classes).
    config.policy_optimizer = getattr(tf.train, config.policy_optimizer)
    config.value_optimizer = getattr(tf.train, config.value_optimizer)
  batch_env = None
  try:
    with tf.device('/cpu:0'):
      batch_env = utility.define_batch_env(
          lambda: _create_environment(config, outdir),
          num_agents, env_processes)
      graph = utility.define_simulation_graph(
          batch_env, config.algorithm, config)
      total_steps = num_episodes * config.max_length
      loop = _define_loop(graph, total_steps)
    saver = utility.define_saver(
        exclude=(r'.*_temporary/.*', r'global_step'))
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
      utility.initialize_variables(
          sess, saver, config.logdir, checkpoint, resume=True)
      # Scores are discarded; the loop is driven only for its side effects
      # (presumably rendering via environments created with ``outdir``).
      for unused_score in loop.run(sess, saver, total_steps):
        pass
  finally:
    if batch_env is not None:
      batch_env.close()
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:39,代码来源:visualize.py


注:本文中的tensorflow.train方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。