當前位置: 首頁>>代碼示例>>Python>>正文


Python v1.RunMetadata方法代碼示例

本文整理匯總了Python中tensorflow.compat.v1.RunMetadata方法的典型用法代碼示例。如果您正苦於以下問題:Python v1.RunMetadata方法的具體用法?Python v1.RunMetadata怎麽用?Python v1.RunMetadata使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.RunMetadata方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: save_checkpoint_chrome_trace

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import RunMetadata [as 別名]
def save_checkpoint_chrome_trace(dataset: str, model_name: str, log_base: PathLike, batch_size: int = 32):
    def trace_solver_solution(save_path: PathLike, train_ds, solver):
        import tensorflow.compat.v1 as tf1
        from tensorflow.python.client import timeline

        data_iter = train_ds.__iter__()
        data_list = [x.numpy() for x in data_iter.next()]
        with tf1.Session() as sess:
            sqrtn_fn, *_ = _build_model_via_solver(dataset, model_name, train_ds.element_spec, solver)
            out = sqrtn_fn(*[tf1.convert_to_tensor(x) for x in data_list])

            run_meta = tf1.RunMetadata()
            sess.run(tf1.global_variables_initializer())
            sess.run(out, options=tf1.RunOptions(trace_level=tf1.RunOptions.FULL_TRACE), run_metadata=run_meta)
            t1 = timeline.Timeline(run_meta.step_stats)
            lctf = t1.generate_chrome_trace_format()

        with Path(save_path).open("w") as f:
            f.write(lctf)

    log_base = Path(log_base)
    log_base.mkdir(parents=True, exist_ok=True)
    train_ds, test_ds = get_data(dataset, batch_size=batch_size)
    trace_solver_solution(log_base / "check_all.json", train_ds, solve_checkpoint_all)
    trace_solver_solution(log_base / "check_sqrtn_noap.json", train_ds, solve_chen_sqrtn_noap) 
開發者ID:parasj,項目名稱:checkmate,代碼行數:27,代碼來源:test_tf2_execution.py

示例2: benchmark

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import RunMetadata [as 別名]
def benchmark(self, image_arrays, trace_filename=None):
    """Benchmark inference latency/throughput.

    Args:
      image_arrays: a list of images in numpy array format.
      trace_filename: If None, specify the filename for saving trace.
    """
    if not self.sess:
      self.build()

    # init session
    self.sess.run(
        self.signitures['prediction'],
        feed_dict={self.signitures['image_arrays']: image_arrays})

    start = time.perf_counter()
    for _ in range(10):
      self.sess.run(
          self.signitures['prediction'],
          feed_dict={self.signitures['image_arrays']: image_arrays})
    end = time.perf_counter()
    inference_time = (end - start) / 10

    print('Per batch inference time: ', inference_time)
    print('FPS: ', self.batch_size / inference_time)

    if trace_filename:
      run_options = tf.RunOptions()
      run_options.trace_level = tf.RunOptions.FULL_TRACE
      run_metadata = tf.RunMetadata()
      self.sess.run(
          self.signitures['prediction'],
          feed_dict={self.signitures['image_arrays']: image_arrays},
          options=run_options,
          run_metadata=run_metadata)
      with tf.io.gfile.GFile(trace_filename, 'w') as trace_file:
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        trace_file.write(trace.generate_chrome_trace_format(show_memory=True)) 
開發者ID:PINTO0309,項目名稱:PINTO_model_zoo,代碼行數:40,代碼來源:inference.py

示例3: get_flops

# 需要導入模塊: from tensorflow.compat import v1 [as 別名]
# 或者: from tensorflow.compat.v1 import RunMetadata [as 別名]
def get_flops(model):
    run_meta = tf.RunMetadata()
    graph = tf.get_default_graph()

    # We use the Keras session graph in the call to the profiler.
    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(graph=graph, run_meta=run_meta, cmd='op', options=opts)

    opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
    params = tf.profiler.profile(graph=graph, run_meta=run_meta, cmd='op', options=opts)

    print('Total FLOPs: {}m float_ops'.format(flops.total_float_ops/1e6))
    print('Total PARAMs: {}m'.format(params.total_parameters/1e6)) 
開發者ID:david8862,項目名稱:keras-YOLOv3-model-set,代碼行數:15,代碼來源:model_statistics.py


注:本文中的tensorflow.compat.v1.RunMetadata方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。