当前位置: 首页>>代码示例>>Python>>正文


Python common_flags.create_model方法代码示例

本文整理汇总了Python中common_flags.create_model方法的典型用法代码示例。如果您想了解Python common_flags.create_model方法的具体用法、调用方式或实际使用示例,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块common_flags的其他用法示例。


下文共展示了common_flags.create_model方法的6个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞,您的评价将有助于系统推荐更好的Python代码示例。

示例1: main

# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def main(_):
  """Builds the evaluation graph and runs slim's continuous evaluation loop.

  The dataset split comes from ``FLAGS.split_name``. Checkpoints are read
  from ``FLAGS.train_log_dir`` and evaluation output goes to
  ``FLAGS.eval_log_dir``, which is created first if it does not exist.

  Args:
    _: Unused positional argument (presumably argv from the app launcher).
  """
  log_dir = FLAGS.eval_log_dir
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  split = common_flags.create_dataset(split_name=FLAGS.split_name)
  ocr_model = common_flags.create_model(
      split.num_char_classes, split.max_sequence_length,
      split.num_of_views, split.null_code)
  batch = data_provider.get_data(
      split,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())
  # No ground-truth labels are fed to the base network during evaluation.
  net_endpoints = ocr_model.create_base(batch.images, labels_one_hot=None)
  # Called for its graph-building side effects; the return value is unused.
  ocr_model.create_loss(batch, net_endpoints)
  summary_ops = ocr_model.create_summaries(
      batch, net_endpoints, split.charset, is_training=False)
  slim.get_or_create_global_step()
  # Expose zero GPU devices to the evaluation session.
  cpu_only = tf.ConfigProto(device_count={"GPU": 0})
  loop_kwargs = dict(
      master=FLAGS.master,
      checkpoint_dir=FLAGS.train_log_dir,
      logdir=log_dir,
      eval_op=summary_ops,
      num_evals=FLAGS.num_batches,
      eval_interval_secs=FLAGS.eval_interval_secs,
      max_number_of_evaluations=FLAGS.number_of_steps,
      session_config=cpu_only)
  slim.evaluation.evaluation_loop(**loop_kwargs)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:30,代码来源:eval.py

示例2: main

# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def main(_):
  """Builds the OCR training graph and hands it to ``train``.

  Args:
    _: Unused positional argument (presumably argv from the app launcher).
  """
  prepare_training_dir()

  train_split = common_flags.create_dataset(split_name=FLAGS.split_name)
  ocr_model = common_flags.create_model(
      train_split.num_char_classes, train_split.max_sequence_length,
      train_split.num_of_views, train_split.null_code)
  train_hparams = get_training_hparams()

  # With FLAGS.ps_tasks == 0 everything stays on the local device; with
  # parameter-server replicas the setter spreads the variables across them.
  placement = tf.train.replica_device_setter(
      FLAGS.ps_tasks, merge_devices=True)
  with tf.device(placement):
    batch = data_provider.get_data(
        train_split,
        FLAGS.batch_size,
        augment=train_hparams.use_augment_input,
        central_crop_size=common_flags.get_crop_size())
    net_endpoints = ocr_model.create_base(batch.images, batch.labels_one_hot)
    loss = ocr_model.create_loss(batch, net_endpoints)
    ocr_model.create_summaries(
        batch, net_endpoints, train_split.charset, is_training=True)
    restore_fn = ocr_model.create_init_fn_to_restore(
        FLAGS.checkpoint, FLAGS.checkpoint_inception)
    if FLAGS.show_graph_stats:
      weight_count = calculate_graph_metrics()
      logging.info('Total number of weights in the graph: %s', weight_count)
    train(loss, restore_fn, train_hparams)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:31,代码来源:train.py

示例3: create_model

# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def create_model(batch_size, dataset_name):
  """Builds an inference graph fed through a uint8 image placeholder.

  NOTE(review): ``dataset_name`` is used only to look up the image size; the
  dataset itself is built from ``FLAGS.split_name`` (and whatever other flags
  ``create_dataset`` reads). Confirm the two refer to the same dataset.

  Args:
    batch_size: Batch dimension baked into the placeholder shape.
    dataset_name: Name passed to ``get_dataset_image_size``.

  Returns:
    A ``(raw_images, endpoints)`` tuple. ``raw_images`` is the uint8
    placeholder of shape ``[batch_size, height, width, 3]``; ``endpoints``
    is the result of ``create_base`` on the preprocessed images.
  """
  width, height = get_dataset_image_size(dataset_name)
  ds = common_flags.create_dataset(split_name=FLAGS.split_name)
  ocr_model = common_flags.create_model(
      num_char_classes=ds.num_char_classes,
      seq_length=ds.max_sequence_length,
      num_views=ds.num_of_views,
      null_code=ds.null_code,
      charset=ds.charset)
  input_shape = [batch_size, height, width, 3]
  raw_images = tf.placeholder(tf.uint8, shape=input_shape)
  # Apply the data-provider preprocessing per image, producing float32.
  float_images = tf.map_fn(
      data_provider.preprocess_image, raw_images, dtype=tf.float32)
  endpoints = ocr_model.create_base(float_images, labels_one_hot=None)
  return raw_images, endpoints
开发者ID:rky0930,项目名称:yolo_v2,代码行数:16,代码来源:demo_inference.py

示例4: run

# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def run(checkpoint, batch_size, dataset_name, image_path_pattern):
  """Restores ``checkpoint`` and predicts text for a batch of images.

  NOTE(review): the result is ``tolist()`` of a string tensor fetch. Under
  Python 3 the entries are likely ``bytes`` rather than ``str``. Callers may
  need to decode them; confirm against the callers before changing this.

  Args:
    checkpoint: Checkpoint path handed to ``ChiefSessionCreator``.
    batch_size: Batch size for both the graph and the image loader.
    dataset_name: Dataset name forwarded to ``create_model``/``load_images``.
    image_path_pattern: Pattern forwarded to ``load_images``.

  Returns:
    ``predicted_text`` fetched from the graph, converted with ``tolist()``.
  """
  placeholder, endpoints = create_model(batch_size, dataset_name)
  batch = load_images(image_path_pattern, batch_size, dataset_name)
  creator = monitored_session.ChiefSessionCreator(
      checkpoint_filename_with_path=checkpoint)
  with monitored_session.MonitoredSession(session_creator=creator) as sess:
    predicted = sess.run(endpoints.predicted_text,
                         feed_dict={placeholder: batch})
  return predicted.tolist()
开发者ID:rky0930,项目名称:yolo_v2,代码行数:14,代码来源:demo_inference.py

示例5: load_model

# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def load_model(checkpoint, batch_size, dataset_name):
  """Builds an inference graph on a float32 placeholder plus a restore fn.

  No preprocessing is added to the graph: the placeholder feeds
  ``create_base`` directly. Callers presumably feed already-preprocessed
  images (confirm).

  Args:
    checkpoint: Checkpoint passed to ``create_init_fn_to_restore``.
    batch_size: Batch dimension baked into the placeholder shape.
    dataset_name: Name passed to ``get_dataset_image_size``.

  Returns:
    A ``(images_placeholder, endpoints, init_fn)`` tuple.
  """
  width, height = get_dataset_image_size(dataset_name)
  ds = common_flags.create_dataset(split_name=FLAGS.split_name)
  model_kwargs = {
      'num_char_classes': ds.num_char_classes,
      'seq_length': ds.max_sequence_length,
      'num_views': ds.num_of_views,
      'null_code': ds.null_code,
      'charset': ds.charset,
  }
  ocr_model = common_flags.create_model(**model_kwargs)
  images_placeholder = tf.placeholder(
      tf.float32, shape=[batch_size, height, width, 3])
  endpoints = ocr_model.create_base(images_placeholder, labels_one_hot=None)
  restore_fn = ocr_model.create_init_fn_to_restore(checkpoint)
  return images_placeholder, endpoints, restore_fn
开发者ID:sshleifer,项目名称:object_detection_kitti,代码行数:16,代码来源:demo_inference.py

示例6: run

# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def run(checkpoint, batch_size, dataset_name, image_path_pattern):
  """Restores ``checkpoint`` and returns the predicted text as ``str``.

  Args:
    checkpoint: Checkpoint path handed to ``ChiefSessionCreator``.
    batch_size: Batch size for both the graph and the image loader.
    dataset_name: Dataset name forwarded to ``create_model``/``load_images``.
    image_path_pattern: Pattern forwarded to ``load_images``.

  Returns:
    A list with each fetched ``predicted_text`` entry decoded as UTF-8.
  """
  placeholder, endpoints = create_model(batch_size, dataset_name)
  batch = load_images(image_path_pattern, batch_size, dataset_name)
  creator = monitored_session.ChiefSessionCreator(
      checkpoint_filename_with_path=checkpoint)
  with monitored_session.MonitoredSession(session_creator=creator) as sess:
    raw = sess.run(endpoints.predicted_text, feed_dict={placeholder: batch})
  # String-tensor fetches come back as bytes; convert each one to text.
  texts = [entry.decode('utf-8') for entry in raw.tolist()]
  return texts
开发者ID:tensorflow,项目名称:models,代码行数:14,代码来源:demo_inference.py


注:本文中的common_flags.create_model方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路开发者贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的许可证(License);未经允许,请勿转载。