当前位置: 首页>>代码示例>>Python>>正文


Python eval_util.EvaluationMetrics类代码示例

本文整理汇总了Python中eval_util.EvaluationMetrics类的典型用法代码示例。如果您正苦于以下问题:Python eval_util.EvaluationMetrics类的具体用法?Python eval_util.EvaluationMetrics怎么用?Python eval_util.EvaluationMetrics使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块eval_util的用法示例。


在下文中一共展示了eval_util.EvaluationMetrics类的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: cal_gap

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def cal_gap(self):
        """Print the 'gap' metric of the weighted ensemble prediction.

        Calls self.predict(), combines self.pred_list with self.weights via
        ensemble(), accumulates the result against self.labels_val in a fresh
        eval_util.EvaluationMetrics object and prints the 'gap' entry of the
        metrics it reports.
        """
        self.predict()
        # 4716 is passed as the first constructor argument -- presumably the
        # number of classes; TODO confirm against eval_util.
        evl_metrics = eval_util.EvaluationMetrics(4716, self.top_k)
        predictions_val = ensemble(self.pred_list, self.weights)

        # Call form of print so the block parses on Python 2 and Python 3.
        print(predictions_val.shape)
        # The third argument is a zero vector with one entry per prediction
        # row; the per-batch return value of accumulate() is not needed here.
        evl_metrics.accumulate(predictions_val,
                               self.labels_val,
                               np.zeros(predictions_val.shape[0]))
        epoch_info_dict = evl_metrics.get()
        # NOTE(review): the "[email protected]" text looks like an artifact of
        # web scraping; restore the original label from the upstream source.
        print(("[email protected]%d:" %self.top_k) + str(epoch_info_dict['gap']))
开发者ID:forwchen,项目名称:yt8m,代码行数:12,代码来源:valid.py

示例2: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Creates a QuickDraw feature reader and instantiates the model and label
  loss named by FLAGS.model and FLAGS.label_loss, builds the graph, then
  calls evaluation_loop repeatedly. The loop stops after the first pass when
  FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    reader = readers.QuickDrawFeatureReader()

    model = find_class_by_name(FLAGS.model, [models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, 2)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:machine-learning-challenge,项目名称:mlc2017-online,代码行数:42,代码来源:eval.py

示例3: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Creates a CatsVsDogs feature reader and instantiates the model and label
  loss named by FLAGS.model and FLAGS.label_loss, builds the graph, then
  calls evaluation_loop repeatedly. The loop stops after the first pass when
  FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    reader = readers.CatsVsDogsFeatureReader()

    model = find_class_by_name(FLAGS.model, [cvd_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    image_id_batch = tf.get_collection("id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, 2)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(image_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:machine-learning-challenge,项目名称:mlc2017-online,代码行数:43,代码来源:eval.py

示例4: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Creates a CatsVsDogs feature reader and instantiates the model and label
  loss named by FLAGS.model and FLAGS.label_loss, builds the graph, then
  calls evaluation_loop repeatedly. The loop stops after the first pass when
  FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    reader = readers.CatsVsDogsFeatureReader()

    model = find_class_by_name(FLAGS.model, [cvd_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    image_id_batch = tf.get_collection("image_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, 2)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(image_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:SSUHan,项目名称:google_ml_challenge,代码行数:43,代码来源:eval.py

示例5: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Creates an MNIST reader and instantiates the model and label loss named by
  FLAGS.model and FLAGS.label_loss, builds the graph, then calls
  evaluation_loop repeatedly. The loop stops after the first pass when
  FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    reader = readers.MnistReader()

    model = find_class_by_name(FLAGS.model, [mnist_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, 2)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:machine-learning-challenge,项目名称:tutorial_mnist,代码行数:42,代码来源:eval.py

示例6: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Picks a frame-level or aggregated YT8M reader depending on
  FLAGS.frame_features, instantiates the model and label loss named by
  FLAGS.model and FLAGS.label_loss, builds the graph, then calls
  evaluation_loop repeatedly. The loop stops after the first pass when
  FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    if FLAGS.frame_features:
      reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                              feature_sizes=feature_sizes)
    else:
      reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                   feature_sizes=feature_sizes)

    model = find_class_by_name(FLAGS.model,
                               [frame_level_models, video_level_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:wangheda,项目名称:youtube-8m,代码行数:52,代码来源:eval.py

示例7: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Picks a frame-level or aggregated YT8M reader depending on
  FLAGS.frame_features, instantiates the model named by FLAGS.model (looked
  up in labels_autoencoder) and the label loss named by FLAGS.label_loss,
  builds the graph, then calls evaluation_loop repeatedly. The loop stops
  after the first pass when FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    if FLAGS.frame_features:
      reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                              feature_sizes=feature_sizes)
    else:
      reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                   feature_sizes=feature_sizes)

    model = find_class_by_name(FLAGS.model, [labels_autoencoder])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:wangheda,项目名称:youtube-8m,代码行数:52,代码来源:eval_embedding.py

示例8: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the two-reader evaluation graph and run the evaluation loop.

  reader1 is a frame-level or aggregated YT8M reader (chosen by
  FLAGS.frame_features) over the regular features; reader2 is an aggregated
  reader over the features named by FLAGS.distill_names. Both are handed to
  build_graph, after which evaluation_loop is called repeatedly. The loop
  stops after the first pass when FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)
    distill_names, distill_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.distill_names, FLAGS.distill_sizes)

    if FLAGS.frame_features:
      reader1 = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                               feature_sizes=feature_sizes)
    else:
      reader1 = readers.YT8MAggregatedFeatureReader(
          feature_names=feature_names, feature_sizes=feature_sizes)
    reader2 = readers.YT8MAggregatedFeatureReader(
        feature_names=distill_names, feature_sizes=distill_sizes)

    model = find_class_by_name(FLAGS.model,
                               [frame_level_models, video_level_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader1=reader1,
        reader2=reader2,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        distill_data_pattern=FLAGS.distill_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    video_id_batch = tf.get_collection("video_id_batch")[0]
    unused_id_batch = tf.get_collection("unused_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader1.num_classes, FLAGS.top_k)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, unused_id_batch,
                                             prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:wangheda,项目名称:youtube-8m,代码行数:59,代码来源:eval_distill.py

示例9: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Picks a frame-level or aggregated YT8M reader depending on
  FLAGS.frame_features. When FLAGS.distill_data_pattern is not None, an
  additional aggregated reader for a single "predictions" feature of size
  4716 is created; otherwise the distill reader is None. The model, label
  loss and feature transformer class are looked up by name from FLAGS, the
  graph is built, and evaluation_loop is called repeatedly. The loop stops
  after the first pass when FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    if FLAGS.frame_features:
      reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                              feature_sizes=feature_sizes)
    else:
      reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                   feature_sizes=feature_sizes)

    if FLAGS.distill_data_pattern is not None:
      distill_reader = readers.YT8MAggregatedFeatureReader(
          feature_names=["predictions"], feature_sizes=[4716])
    else:
      distill_reader = None

    model = find_class_by_name(FLAGS.model,
                               [frame_level_models, video_level_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
    # The transformer is passed on as a class; it is not instantiated here.
    transformer_class = find_class_by_name(FLAGS.feature_transformer,
                                           [feature_transform])

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        transformer_class=transformer_class,
        distill_reader=distill_reader,
        batch_size=FLAGS.batch_size)

    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:wangheda,项目名称:youtube-8m,代码行数:62,代码来源:eval.py

示例10: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Picks a frame-level or aggregated YT8M reader depending on
  FLAGS.frame_features, configured with FLAGS.truncated_num_classes and
  FLAGS.decode_zlib. The model and label loss are looked up by name from
  FLAGS, the graph is built, and evaluation_loop is called repeatedly. The
  loop stops after the first pass when FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    if FLAGS.frame_features:
      reader = readers.YT8MFrameFeatureReader(
          num_classes=FLAGS.truncated_num_classes,
          decode_zlib=FLAGS.decode_zlib,
          feature_names=feature_names, feature_sizes=feature_sizes)
    else:
      reader = readers.YT8MAggregatedFeatureReader(
          num_classes=FLAGS.truncated_num_classes,
          decode_zlib=FLAGS.decode_zlib,
          feature_names=feature_names, feature_sizes=feature_sizes)

    model = find_class_by_name(FLAGS.model,
                               [frame_level_models, video_level_models,
                                xp_frame_level_models, xp_video_level_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:mpekalski,项目名称:Y8M,代码行数:57,代码来源:eval.py

示例11: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Picks a frame-level or aggregated YT8M reader depending on
  FLAGS.frame_features, configured with FLAGS.truncated_num_classes and
  FLAGS.decode_zlib. The model and label loss are looked up by name from
  FLAGS, the graph is built, and evaluation_loop is called repeatedly. The
  loop stops after the first pass when FLAGS.run_once is set.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    if FLAGS.frame_features:
      reader = readers.YT8MFrameFeatureReader(
          num_classes=FLAGS.truncated_num_classes,
          decode_zlib=FLAGS.decode_zlib,
          feature_names=feature_names,
          feature_sizes=feature_sizes)
    else:
      reader = readers.YT8MAggregatedFeatureReader(
          num_classes=FLAGS.truncated_num_classes,
          decode_zlib=FLAGS.decode_zlib,
          feature_names=feature_names,
          feature_sizes=feature_sizes)

    model = find_class_by_name(FLAGS.model,
                               [frame_level_models, video_level_models,
                                xp_frame_level_models, xp_video_level_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:mpekalski,项目名称:Y8M,代码行数:59,代码来源:eval.py

示例12: evaluate

# 需要导入模块: import eval_util [as 别名]
# 或者: from eval_util import EvaluationMetrics [as 别名]
def evaluate():
  """Build the evaluation graph and run the evaluation loop.

  Picks a frame-level or aggregated YT8M reader depending on
  FLAGS.frame_features, instantiates the model and label loss named by
  FLAGS.model and FLAGS.label_loss, builds the graph, then calls
  evaluation_loop repeatedly. The loop stops after the first pass when
  FLAGS.run_once is set. The "coarse" outputs are currently disabled and
  kept only as commented-out code.

  Raises:
    IOError: If FLAGS.eval_data_pattern is empty or unset.
  """
  # NOTE(review): the seed is set before the tf.Graph below becomes the
  # default graph -- confirm that it actually reaches the evaluation graph.
  tf.set_random_seed(0)  # for reproducibility
  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        FLAGS.feature_names, FLAGS.feature_sizes)

    if FLAGS.frame_features:
      reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                              feature_sizes=feature_sizes)
    else:
      reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                   feature_sizes=feature_sizes)

    model = find_class_by_name(FLAGS.model,
                               [frame_level_models, video_level_models])()
    label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

    # Test the value, not identity: `is ""` depends on string interning and
    # is not guaranteed to be true for an empty flag value.
    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    # build_graph publishes its output tensors through graph collections.
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    # Disabled coarse outputs:
    # coarse_prediction_batch = tf.get_collection("coarse_predictions")[0]
    # coarse_label_batch = tf.get_collection("coarse_labels")[0]
    # coarse_loss = tf.get_collection("coarse_loss")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

    # Feed each call's return value back in as the last seen global step.
    last_global_step_val = -1
    while True:
      # Disabled extra arguments for the call below:
      # coarse_prediction_batch, coarse_label_batch, coarse_loss
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
开发者ID:Tsingularity,项目名称:youtube-8m,代码行数:60,代码来源:eval.py


注:本文中的eval_util.EvaluationMetrics类示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。