當前位置: 首頁>>代碼示例>>Python>>正文


Python trainer.train方法代碼示例

本文整理匯總了Python中object_detection.legacy.trainer.train方法的典型用法代碼示例。如果您正苦於以下問題:Python trainer.train方法的具體用法?Python trainer.train怎麼用?Python trainer.train使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在object_detection.legacy.trainer的用法示例。


在下文中一共展示了trainer.train方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: updates

# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def updates(self):
    """Returns a list of update operators for this model.

    Returns a list of update operators for this model that must be executed at
    each training step. The estimator's train op needs to have a control
    dependency on these updates.

    This implementation defines no update operators, so it returns an empty
    list (rather than ``None``) to honour the documented return type.

    Returns:
      A list of update operators; always empty here.
    """
    return []
開發者ID:ahmetozlu,項目名稱:vehicle_counting_tensorflow,代碼行數:13,代碼來源:trainer_test.py

示例2: test_configure_trainer_and_train_two_steps

# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def test_configure_trainer_and_train_two_steps(self):
    """Parses a two-step Adam training config and runs trainer.train on it.

    The config enables brightness and contrast augmentation; training runs
    single-process (one clone on CPU, no ps tasks) into a temporary directory.
    """
    proto_text = """
    optimizer {
      adam_optimizer {
        learning_rate {
          constant_learning_rate {
            learning_rate: 0.01
          }
        }
      }
    }
    data_augmentation_options {
      random_adjust_brightness {
        max_delta: 0.2
      }
    }
    data_augmentation_options {
      random_adjust_contrast {
        min_delta: 0.7
        max_delta: 1.1
      }
    }
    num_steps: 2
    """
    # text_format.Merge fills the message in place and returns that message.
    parsed_config = text_format.Merge(proto_text, train_pb2.TrainConfig())

    # Single local worker: no master address, chief task 0, no ps tasks.
    train_kwargs = dict(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=parsed_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=self.get_temp_dir(),
    )
    trainer.train(**train_kwargs)
開發者ID:ahmetozlu,項目名稱:vehicle_counting_tensorflow,代碼行數:44,代碼來源:trainer_test.py

示例3: test_configure_trainer_with_multiclass_scores_and_train_two_steps

# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
    """Runs trainer.train for two steps with `use_multiclass_scores` enabled.

    Same Adam + brightness/contrast augmentation config as the plain two-step
    test, plus `use_multiclass_scores: true`; training runs single-process
    into a temporary directory.
    """
    proto_text = """
    optimizer {
      adam_optimizer {
        learning_rate {
          constant_learning_rate {
            learning_rate: 0.01
          }
        }
      }
    }
    data_augmentation_options {
      random_adjust_brightness {
        max_delta: 0.2
      }
    }
    data_augmentation_options {
      random_adjust_contrast {
        min_delta: 0.7
        max_delta: 1.1
      }
    }
    num_steps: 2
    use_multiclass_scores: true
    """
    # text_format.Merge fills the message in place and returns that message.
    parsed_config = text_format.Merge(proto_text, train_pb2.TrainConfig())

    # Single local worker: no master address, chief task 0, no ps tasks.
    train_kwargs = dict(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=parsed_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=self.get_temp_dir(),
    )
    trainer.train(**train_kwargs)
開發者ID:ahmetozlu,項目名稱:vehicle_counting_tensorflow,代碼行數:44,代碼來源:trainer_test.py

示例4: main

# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def main(_):
  """Builds model and input functions from config, then runs training.

  Configs come from `--pipeline_config_path` if set, otherwise from the
  separate per-component config files. Cluster topology is read from the
  `TF_CONFIG` environment variable (JSON with optional "cluster" and "task"
  entries). With no "ps" job the process trains as a single local worker;
  otherwise it starts a gRPC server for this task and either blocks (for a
  "ps" task) or trains as a distributed worker.

  Args:
    _: Unused positional argument (presumably argv from the app runner).

  Raises:
    ValueError: if `--train_dir` is not set, or if the cluster lists extra
      workers but no ps tasks.
  """
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if not FLAGS.train_dir:
    raise ValueError('`train_dir` is missing.')
  if FLAGS.pipeline_config_path:
    model_config, train_config, input_config = get_configs_from_pipeline_file()
  else:
    model_config, train_config, input_config = get_configs_from_multiple_files()

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  create_input_dict_fn = functools.partial(
      input_reader_builder.build, input_config)

  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  # Fall back to a lone "master" task when TF_CONFIG does not describe one.
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  # Expose the task dict's keys as attributes (task_info.type, .index).
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Total worker replicas counts every "worker" task plus the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  # worker_replicas is always >= 1, so the presence of ps tasks alone decides
  # whether to set up distributed training. ps_tasks > 0 also implies that
  # cluster_data was non-empty, so `cluster` is a ClusterSpec here.
  if ps_tasks > 0:
    server = tf.train.Server(cluster, protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      # A ps task only hosts its server; block here instead of training.
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target

  trainer.train(create_input_dict_fn, model_fn, train_config, master, task,
                FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu, ps_tasks,
                worker_job_name, is_chief, FLAGS.train_dir)
開發者ID:maartensukel,項目名稱:garbage-object-detection-tensorflow,代碼行數:57,代碼來源:train.py


注:本文中的object_detection.legacy.trainer.train方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。