本文整理匯總了Python中object_detection.legacy.trainer.train方法的典型用法代碼示例。如果您正苦於以下問題：Python trainer.train方法的具體用法？Python trainer.train怎麼用？Python trainer.train使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊object_detection.legacy.trainer的用法示例。
在下文中一共展示了trainer.train方法的4個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: updates
# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def updates(self):
    """Return the update operators that must run at every training step.

    The estimator's train op is expected to take a control dependency on the
    operators returned here, so that they execute once per training step.

    This implementation is an empty stub and returns ``None`` rather than a
    list; presumably subclasses override it with the real operators (confirm
    against the concrete model classes).

    Returns:
      A list of update operators (``None`` from this stub).
    """
    return None
示例2: test_configure_trainer_and_train_two_steps
# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def test_configure_trainer_and_train_two_steps(self):
    """Smoke-test the legacy trainer: configure it and run two steps.

    Parses a text-format ``TrainConfig`` and runs ``trainer.train`` with a
    single CPU clone, writing into a temporary directory. The config uses a
    constant-learning-rate Adam optimizer, random brightness and contrast
    augmentation, and ``num_steps: 2``.

    The test makes no explicit assertions; it passes if training completes
    without raising.
    """
    config_text = """
optimizer {
adam_optimizer {
learning_rate {
constant_learning_rate {
learning_rate: 0.01
}
}
}
}
data_augmentation_options {
random_adjust_brightness {
max_delta: 0.2
}
}
data_augmentation_options {
random_adjust_contrast {
min_delta: 0.7
max_delta: 1.1
}
}
num_steps: 2
"""
    config = train_pb2.TrainConfig()
    text_format.Merge(config_text, config)
    output_dir = self.get_temp_dir()
    # Local, non-distributed settings: one clone placed on CPU, one worker
    # replica, no parameter-server tasks, and this task acts as chief.
    cluster_kwargs = {
        'master': '',
        'task': 0,
        'num_clones': 1,
        'worker_replicas': 1,
        'clone_on_cpu': True,
        'ps_tasks': 0,
        'worker_job_name': 'worker',
        'is_chief': True,
    }
    trainer.train(create_tensor_dict_fn=get_input_function,
                  create_model_fn=FakeDetectionModel,
                  train_config=config,
                  train_dir=output_dir,
                  **cluster_kwargs)
示例3: test_configure_trainer_with_multiclass_scores_and_train_two_steps
# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
    """Smoke-test the legacy trainer with multiclass scores for two steps.

    This test matches the basic two-step trainer test, except that the
    text-format ``TrainConfig`` also sets ``use_multiclass_scores: true``.
    The rest of the config is a constant-learning-rate Adam optimizer, random
    brightness and contrast augmentation, and ``num_steps: 2``.

    Training runs on a single CPU clone in a temporary directory. The test
    makes no explicit assertions; it passes if training does not raise.
    """
    config_text = """
optimizer {
adam_optimizer {
learning_rate {
constant_learning_rate {
learning_rate: 0.01
}
}
}
}
data_augmentation_options {
random_adjust_brightness {
max_delta: 0.2
}
}
data_augmentation_options {
random_adjust_contrast {
min_delta: 0.7
max_delta: 1.1
}
}
num_steps: 2
use_multiclass_scores: true
"""
    config = train_pb2.TrainConfig()
    text_format.Merge(config_text, config)
    output_dir = self.get_temp_dir()
    # Local, non-distributed settings: one clone placed on CPU, one worker
    # replica, no parameter-server tasks, and this task acts as chief.
    cluster_kwargs = {
        'master': '',
        'task': 0,
        'num_clones': 1,
        'worker_replicas': 1,
        'clone_on_cpu': True,
        'ps_tasks': 0,
        'worker_job_name': 'worker',
        'is_chief': True,
    }
    trainer.train(create_tensor_dict_fn=get_input_function,
                  create_model_fn=FakeDetectionModel,
                  train_config=config,
                  train_dir=output_dir,
                  **cluster_kwargs)
示例4: main
# 需要導入模塊: from object_detection.legacy import trainer [as 別名]
# 或者: from object_detection.legacy.trainer import train [as 別名]
def main(_):
    """Entry point for the legacy object-detection training script.

    Loads the model, train and input configs. They come from a single pipeline
    config file when ``FLAGS.pipeline_config_path`` is set, and from separate
    config files otherwise. The cluster topology is derived from the
    ``TF_CONFIG`` environment variable, and training is started through
    ``trainer.train``.

    Args:
      _: Unused positional argument (presumably the argv list passed by
        ``tf.app.run`` -- confirm against the caller).

    Raises:
      ValueError: if ``FLAGS.train_dir`` is empty. Also raised if
        ``TF_CONFIG`` implies more than one worker replica but no
        parameter-server tasks.
    """
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if not FLAGS.train_dir:
        raise ValueError('`train_dir` is missing.')
    if FLAGS.pipeline_config_path:
        model_config, train_config, input_config = (
            get_configs_from_pipeline_file())
    else:
        model_config, train_config, input_config = (
            get_configs_from_multiple_files())

    model_fn = functools.partial(
        model_builder.build,
        model_config=model_config,
        is_training=True)
    create_input_dict_fn = functools.partial(
        input_reader_builder.build, input_config)

    # TF_CONFIG is a JSON document. 'cluster' presumably maps job names to
    # lists of task addresses, and 'task' holds the 'type' and 'index' of
    # this process (TODO confirm against the deployment tooling).
    env = json.loads(os.environ.get('TF_CONFIG', '{}'))
    cluster_data = env.get('cluster', None)
    cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
    # Without an explicit task entry, behave as master task 0.
    task_data = env.get('task', None) or {'type': 'master', 'index': 0}

    # Parameters for a single worker.
    ps_tasks = 0
    worker_replicas = 1
    worker_job_name = 'lonely_worker'
    task = 0
    is_chief = True
    master = ''

    if cluster_data and 'worker' in cluster_data:
        # Number of total worker replicas include "worker"s and the "master".
        worker_replicas = len(cluster_data['worker']) + 1
    if cluster_data and 'ps' in cluster_data:
        ps_tasks = len(cluster_data['ps'])

    if worker_replicas > 1 and ps_tasks < 1:
        raise ValueError('At least 1 ps task is needed for distributed training.')

    # worker_replicas is always >= 1, so distributed training is set up
    # exactly when at least one ps task is configured. A non-empty
    # cluster_data is then guaranteed, so `cluster` is a ClusterSpec here.
    if ps_tasks > 0:
        task_type = task_data['type']
        task_index = task_data['index']
        server = tf.train.Server(cluster, protocol='grpc',
                                 job_name=task_type,
                                 task_index=task_index)
        if task_type == 'ps':
            # A 'ps' task only hosts the server: join() blocks, and training
            # is never started from this process.
            server.join()
            return
        worker_job_name = '%s/task:%d' % (task_type, task_index)
        task = task_index
        is_chief = (task_type == 'master')
        master = server.target

    trainer.train(create_input_dict_fn, model_fn, train_config, master, task,
                  FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu,
                  ps_tasks, worker_job_name, is_chief, FLAGS.train_dir)