当前位置: 首页>>代码示例>>Python>>正文


Python debug_utils.watch_graph函数代码示例

本文整理汇总了Python中tensorflow.python.debug.lib.debug_utils.watch_graph函数的典型用法代码示例。如果您正苦于以下问题:Python watch_graph函数的具体用法?Python watch_graph怎么用?Python watch_graph使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了watch_graph函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: createAndRunGraphWithWhileLoop

  def createAndRunGraphWithWhileLoop(self):
    """Create and run a TensorFlow Graph with a while loop to generate dumps.

    Sets the following attributes on the test instance:
      dump_root: temporary directory (from `self.get_temp_dir()`) that the
        debug dumps are written to via a `file://` debug URL.
      curr_file_path: absolute path of the source file that defines this
        method.
      traceback_first_line: value returned by `line_number_above()` called
        directly after the `loop_body` definition.
      dump: `debug_data.DebugDumpDir` loaded from `dump_root` using the run's
        partition graphs, with the Python graph attached.
    """

    self.dump_root = self.get_temp_dir()
    # Path of the file containing this method, taken from the current frame;
    # presumably the file that source annotation is checked against.
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      # NOTE(review): `line_number_above()` presumably returns the line number
      # of the statement just above its call site (the `loop_body` lambda).
      # Source-annotation checks may also rely on the relative line offsets of
      # the statements below, so do not insert, remove or reorder any lines
      # between `loop_body` and the `while_loop` call.
      loop_body = lambda i: math_ops.add(i, 2)
      self.traceback_first_line = line_number_above()

      loop_cond = lambda i: math_ops.less(i, 16)

      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      # Watch every node in the graph and dump to `dump_root`. Partition graphs
      # are requested so the dump directory can be loaded with them below.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(loop, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:27,代码来源:source_utils_test.py

示例2: before_run

  def before_run(self, run_context):
    """Build debug-watching RunOptions for the upcoming Session.run() call.

    On first use, this hook initializes its DumpingDebugWrapperSession base
    class in place. Every call then bumps the run-call counter, asks
    `_prepare_run_watch_config` for this run's debug URLs, debug ops and
    node-name / op-type regex whitelists, and applies them to a fresh
    `config_pb2.RunOptions` through `debug_utils.watch_graph`.

    Args:
      run_context: The hook's run context; its `session` and `original_args`
        (fetches and feed_dict) are read.

    Returns:
      `session_run_hook.SessionRunArgs` with no fetches and no feed_dict,
      carrying only the decorated run options.
    """
    if not self._wrapper_initialized:
      base_init = dumping_wrapper.DumpingDebugWrapperSession.__init__
      base_init(self, run_context.session, self._session_root,
                watch_fn=self._watch_fn, log_usage=self._log_usage)
      self._wrapper_initialized = True

    self._run_call_count += 1

    args = run_context.original_args
    watch_config = self._prepare_run_watch_config(args.fetches, args.feed_dict)
    urls, ops, node_regex, op_type_regex = watch_config

    decorated = config_pb2.RunOptions()
    debug_utils.watch_graph(
        decorated,
        run_context.session.graph,
        debug_urls=urls,
        debug_ops=ops,
        node_name_regex_whitelist=node_regex,
        op_type_regex_whitelist=op_type_regex)
    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=decorated)
开发者ID:brainwy12,项目名称:tensorflow,代码行数:27,代码来源:hooks.py

示例3: testWatchGraph_allNodes

  def testWatchGraph_allNodes(self):
    """Without whitelists, watch_graph watches every node that has outputs.

    Checks the number of watches, that each watch carries both debug ops and
    the debug URL, and that all expected node names are present.
    """
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_ops=["DebugIdentity", "DebugNanCount"],
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(self._expected_num_nodes, len(debug_watch_opts))

    # Verify that each of the nodes in the graph with output tensors in the
    # graph have debug tensor watch.
    node_names = self._verify_watches(debug_watch_opts, 0,
                                      ["DebugIdentity", "DebugNanCount"],
                                      ["file:///tmp/tfdbg_1"])

    # Verify the node names. assertIn reports the missing name and the actual
    # collection on failure, unlike assertTrue(name in node_names).
    expected_node_names = (
        "a1_init", "a1", "a1/Assign", "a1/read",
        "b_init", "b", "b/Assign", "b/read",
        "c", "p1", "s",
    )
    for expected_name in expected_node_names:
      self.assertIn(expected_name, node_names)
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:debug_utils_test.py

示例4: before_run

  def before_run(self, run_context):
    """Build debug-watching RunOptions for the upcoming Session.run() call.

    Lazily creates a DumpingDebugWrapperSession around the run's session on
    first use, increments its run-call counter, and obtains this run's debug
    URLs and watch options from the wrapper. Those options (debug ops, regex
    whitelists and the op-creation-failure tolerance flag) are applied to a
    fresh `config_pb2.RunOptions` with `debug_utils.watch_graph`.

    Args:
      run_context: The hook's run context; its `session` and `original_args`
        (fetches and feed_dict) are read.

    Returns:
      `session_run_hook.SessionRunArgs` with no fetches and no feed_dict,
      carrying only the decorated run options.
    """
    if not self._session_wrapper:
      self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    wrapper = self._session_wrapper
    wrapper.increment_run_call_count()

    args = run_context.original_args
    # pylint: disable=protected-access
    urls, opts = wrapper._prepare_run_watch_config(args.fetches, args.feed_dict)
    # pylint: enable=protected-access

    decorated = config_pb2.RunOptions()
    tolerate_failures = opts.tolerate_debug_op_creation_failures
    debug_utils.watch_graph(
        decorated,
        run_context.session.graph,
        debug_urls=urls,
        debug_ops=opts.debug_ops,
        node_name_regex_whitelist=opts.node_name_regex_whitelist,
        op_type_regex_whitelist=opts.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=opts.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=tolerate_failures)
    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=decorated)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:30,代码来源:hooks.py

示例5: before_run

  def before_run(self, run_context):
    """Build debug-watching RunOptions for the upcoming Session.run() call.

    On first use, this hook initializes its DumpingDebugWrapperSession base
    class in place rather than holding a wrapper instance. Each call then
    increments the run-call counter, asks `_prepare_run_watch_config` for the
    debug URLs and watch options of this run, and applies them to a fresh
    `config_pb2.RunOptions` via `debug_utils.watch_graph`.

    Args:
      run_context: The hook's run context; its `session` and `original_args`
        (fetches and feed_dict) are read.

    Returns:
      `session_run_hook.SessionRunArgs` with no fetches and no feed_dict,
      carrying only the decorated run options.
    """
    if not self._wrapper_initialized:
      # TODO(cais): Make this hook have a DumpingDebugWrapperSession property
      # instead of subclassing DumpingDebugWrapperSession.
      base_init = dumping_wrapper.DumpingDebugWrapperSession.__init__
      base_init(
          self,
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)
      self._wrapper_initialized = True

    self._run_call_count += 1

    args = run_context.original_args
    urls, opts = self._prepare_run_watch_config(args.fetches, args.feed_dict)

    decorated = config_pb2.RunOptions()
    debug_utils.watch_graph(
        decorated,
        run_context.session.graph,
        debug_urls=urls,
        debug_ops=opts.debug_ops,
        node_name_regex_whitelist=opts.node_name_regex_whitelist,
        op_type_regex_whitelist=opts.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=opts.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=(
            opts.tolerate_debug_op_creation_failures))
    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=decorated)
开发者ID:finardi,项目名称:tensorflow,代码行数:32,代码来源:hooks.py

示例6: testToggleBreakpointsWorks

  def testToggleBreakpointsWorks(self):
    """Toggle gated-gRPC breakpoints on two debug tensors across four runs.

    Runs 0 and 2 request breakpoints at delta_1:0:DebugIdentity and
    delta_2:0:DebugIdentity. Runs 1 and 3 request their removal. After each
    run the test checks the variable updates, the published debug tensor
    values (in runs 0 and 2) and the server's breakpoint set.
    """
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      v_2 = variables.VariableV1(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])

      for i in xrange(4):
        self._server_1.clear_data()

        if i in (0, 2):
          # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)
          self._server_1.request_watch(
              "delta_2", 0, "DebugIdentity", breakpoint=True)
        else:
          # Disable the breakpoints in runs 1 and 3.
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

        output = sess.run([inc_v_1, inc_v_2],
                          options=run_options, run_metadata=run_metadata)
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50.0 - 5.0 * (i + 1)],
                            output)

        if i in (0, 2):
          # During runs 0 and 2, the server should have received the published
          # debug tensors delta_1:0:DebugIdentity and delta_2:0:DebugIdentity.
          # The breakpoints should have been unblocked by EventReply responses
          # from the server.
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
          # After the runs, the server should have properly registered the
          # breakpoints due to the request_watch calls.
          self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                               ("delta_2", 0, "DebugIdentity")},
                              self._server_1.breakpoints)
        else:
          # After the end of runs 1 and 3, the server has received the requests
          # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
          self.assertSetEqual(set(), self._server_1.breakpoints)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:57,代码来源:session_debug_grpc_test.py

示例7: _decorate_options_for_debug

  def _decorate_options_for_debug(self, options, graph):
    """Attach debug tensor watches for `graph` to `options`, in place.

    Fills `options.debug_options.debug_tensor_watch_opts` by calling
    `debug_utils.watch_graph` with only the debug URLs returned by
    `self._get_run_debug_urls()` (no whitelists or explicit debug ops), and
    enables `output_partition_graphs`.

    Args:
      options: (config_pb2.RunOptions) The RunOptions instance to be modified.
      graph: A TensorFlow Graph object whose nodes are to be watched.
    """
    run_debug_urls = self._get_run_debug_urls()
    debug_utils.watch_graph(options, graph, debug_urls=run_debug_urls)
    options.output_partition_graphs = True
开发者ID:aravindvcyber,项目名称:tensorflow,代码行数:11,代码来源:hooks.py

示例8: testWatchGraph_tensorDTypeWhitelist

  def testWatchGraph_tensorDTypeWhitelist(self):
    """Only tensors whose dtype name matches ".*_ref" should be watched."""
    url = "file:///tmp/tfdbg_1"
    debug_utils.watch_graph(
        self._run_options, self._graph, debug_urls=url,
        tensor_dtype_regex_whitelist=".*_ref")

    watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    watched = self._verify_watches(watch_opts, 0, ["DebugIdentity"], [url])
    expected = ["a1", "a1/Assign", "b", "b/Assign"]
    self.assertItemsEqual(expected, watched)
开发者ID:aeverall,项目名称:tensorflow,代码行数:11,代码来源:debug_utils_test.py

示例9: testWatchGraph_opTypeWhitelist

  def testWatchGraph_opTypeWhitelist(self):
    """Only nodes whose op type matches "(Variable|MatMul)" should be watched."""
    url = "file:///tmp/tfdbg_1"
    debug_utils.watch_graph(
        self._run_options, self._graph, debug_urls=url,
        op_type_regex_whitelist="(Variable|MatMul)")

    watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    watched = self._verify_watches(watch_opts, 0, ["DebugIdentity"], [url])
    expected = sorted(["a1", "b", "p1"])
    self.assertEqual(expected, sorted(watched))
开发者ID:aeverall,项目名称:tensorflow,代码行数:11,代码来源:debug_utils_test.py

示例10: testWatchGraph_nodeNameAndOpTypeWhitelists

  def testWatchGraph_nodeNameAndOpTypeWhitelists(self):
    """A node must match both the name and op-type whitelists to be watched."""
    url = "file:///tmp/tfdbg_1"
    debug_utils.watch_graph(
        self._run_options, self._graph, debug_urls=url,
        node_name_regex_whitelist="([a-z]+1$)",
        op_type_regex_whitelist="(MatMul)")

    watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    watched = self._verify_watches(watch_opts, 0, ["DebugIdentity"], [url])
    self.assertEqual(["p1"], watched)
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:debug_utils_test.py

示例11: _decorate_options_for_debug

 def _decorate_options_for_debug(self, options, graph, watch_options):
   """Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging.

   Args:
     options: (config_pb2.RunOptions) The RunOptions instance to be modified
       in place; its `output_partition_graphs` field is also set to True.
     graph: A TensorFlow Graph object whose nodes are to be watched.
     watch_options: Per-run watch settings. Its `debug_ops`, node-name /
       op-type / tensor-dtype regex whitelists, and
       `tolerate_debug_op_creation_failures` are forwarded to
       `debug_utils.watch_graph`.
   """
   debug_utils.watch_graph(
       options,
       graph,
       debug_urls=self._get_run_debug_urls(),
       # Forward the requested debug ops. Without this argument, watch_graph
       # falls back to its default op (DebugIdentity) and any debug ops
       # chosen via watch_options are silently ignored.
       debug_ops=watch_options.debug_ops,
       node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
       op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
       tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
       tolerate_debug_op_creation_failures=(
           watch_options.tolerate_debug_op_creation_failures))
   options.output_partition_graphs = True
开发者ID:finardi,项目名称:tensorflow,代码行数:12,代码来源:hooks.py

示例12: testWatchGraph_nodeNameWhitelist

  def testWatchGraph_nodeNameWhitelist(self):
    """Only nodes whose names match the whitelist regex should be watched."""
    url = "file:///tmp/tfdbg_1"
    name_regex = "(a1$|a1_init$|a1/.*|p1$)"
    debug_utils.watch_graph(
        self._run_options, self._graph, debug_urls=url,
        node_name_regex_whitelist=name_regex)

    watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    watched = self._verify_watches(watch_opts, 0, ["DebugIdentity"], [url])
    expected = sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"])
    self.assertEqual(expected, sorted(watched))
开发者ID:aeverall,项目名称:tensorflow,代码行数:13,代码来源:debug_utils_test.py

示例13: testGradientsValuesFromDumpWorks

  def testGradientsValuesFromDumpWorks(self):
    """gradient_values_from_dump() recovers watched gradients from a dump.

    Watches the gradients of w, u and y while building a training op. It then
    runs one training step with all graph nodes dumped to a temporary
    directory, and checks the gradient values read back from that dump. A
    tensor whose gradients were never watched (v) must raise LookupError.
    """
    y = math_ops.add(self.w, -1.0, name="y")
    z = math_ops.square(y, name="z")

    grad_debugger = debug_gradients.GradientsDebugger()
    with grad_debugger.watch_gradients_by_tensors(
        self.sess.graph, [self.w, self.u, y]):
      train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(z)

    self.sess.run(variables.global_variables_initializer())

    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    dump_dir = tempfile.mkdtemp()
    # Register removal right away so the dump directory is deleted even when an
    # assertion below fails; a trailing rmtree would be skipped in that case.
    self.addCleanup(shutil.rmtree, dump_dir)
    debug_url = "file://" + dump_dir
    debug_utils.watch_graph(
        run_options,
        self.sess.graph,
        debug_urls=debug_url)
    run_metadata = config_pb2.RunMetadata()
    self.assertAllClose(2.0, self.sess.run(self.u))
    self.sess.run(train_op, options=run_options, run_metadata=run_metadata)
    self.assertAllClose(-1.0, self.sess.run(self.u))

    dump = debug_data.DebugDumpDir(
        dump_dir, partition_graphs=run_metadata.partition_graphs)
    dump.set_python_graph(self.sess.graph)

    y_grad_values = debug_gradients.gradient_values_from_dump(
        grad_debugger, y, dump)
    self.assertEqual(1, len(y_grad_values))
    self.assertAllClose(10.0, y_grad_values[0])

    w_grad_values = debug_gradients.gradient_values_from_dump(
        grad_debugger, self.w, dump)
    self.assertEqual(1, len(w_grad_values))
    self.assertAllClose(10.0, w_grad_values[0])

    u_grad_values = debug_gradients.gradient_values_from_dump(
        grad_debugger, self.u, dump)
    self.assertEqual(1, len(u_grad_values))
    self.assertAllClose(30.0, u_grad_values[0])

    with self.assertRaisesRegexp(
        LookupError,
        r"This GradientsDebugger has not received any gradient tensor for "
        r"x-tensor v:0"):
      debug_gradients.gradient_values_from_dump(grad_debugger, self.v, dump)
开发者ID:Lin-jipeng,项目名称:tensorflow,代码行数:50,代码来源:debug_gradients_test.py

示例14: testToggleBreakpointWorks

  def testToggleBreakpointWorks(self):
    """Toggle a gated-gRPC breakpoint on delta:0:DebugIdentity over four runs.

    Runs 0 and 2 request the breakpoint and runs 1 and 3 request its removal.
    Per the in-body note, each request takes effect in the following
    Session.run() call, so the debug tensor value is checked during runs 1
    and 3.
    """
    with session.Session(config=no_rewrite_session_config()) as sess:
      v = variables.Variable(50.0, name="v")
      delta = constant_op.constant(5.0, name="delta")
      inc_v = state_ops.assign_add(v, delta, name="inc_v")

      sess.run(v.initializer)

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])

      server = self._server_1
      for run_index in xrange(4):
        server.clear_data()
        enable_breakpoint = run_index in (0, 2)

        # N.B.: These requests will be fulfilled not in this debugged
        # Session.run() invocation, but in the next one.
        if enable_breakpoint:
          server.request_watch("delta", 0, "DebugIdentity", breakpoint=True)
        else:
          server.request_unwatch("delta", 0, "DebugIdentity")

        output = sess.run(inc_v, options=run_options, run_metadata=run_metadata)
        self.assertAllClose(50.0 + 5.0 * (run_index + 1), output)

        if not enable_breakpoint:
          # Runs 1 and 3: the server should have received the published debug
          # tensor delta:0:DebugIdentity, with the breakpoint unblocked by
          # EventReply responses from the server. The request_unwatch call
          # should then have removed the breakpoint.
          self.assertAllClose(
              [5.0], server.debug_tensor_values["delta:0:DebugIdentity"])
          self.assertSetEqual(set(), server.breakpoints)
        else:
          # Runs 0 and 2: the server has received the request to enable the
          # breakpoint and should be tracking it.
          self.assertSetEqual({("delta", 0, "DebugIdentity")},
                              server.breakpoints)
开发者ID:chdinh,项目名称:tensorflow,代码行数:48,代码来源:session_debug_grpc_test.py

示例15: testToggleWatchesOnCoreMetadata

  def testToggleWatchesOnCoreMetadata(self):
    """Gated-gRPC watches toggled by the server from core metadata.

    The test server is started with `toggle_watch_on_core_metadata` naming
    toggled_1:0:DebugIdentity and toggled_2:0:DebugIdentity. The assertions
    expect both debug tensor values in runs 0 and 2 and no values in runs 1
    and 3.
    """
    (_, debug_server_url, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         dump_to_filesystem=False,
         toggle_watch_on_core_metadata=[("toggled_1", 0, "DebugIdentity"),
                                        ("toggled_2", 0, "DebugIdentity")])
    # Registered so the server and its thread can be shut down by the test
    # fixture (handled outside this method).
    self._servers_and_threads.append((server, server_thread))

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      v_2 = variables.VariableV1(-50.0, name="v_2")
      # These two nodes have names that match those in the
      # toggle_watch_on_core_metadata argument used when calling
      # start_server_on_separate_thread().
      toggled_1 = constant_op.constant(5.0, name="toggled_1")
      toggled_2 = constant_op.constant(-5.0, name="toggled_2")
      inc_v_1 = state_ops.assign_add(v_1, toggled_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, toggled_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[debug_server_url])

      for i in xrange(4):
        server.clear_data()

        sess.run([inc_v_1, inc_v_2],
                 options=run_options, run_metadata=run_metadata)

        if i % 2 == 0:
          # Even runs: both toggled watches are expected to be active.
          self.assertEqual(2, len(server.debug_tensor_values))
          self.assertAllClose(
              [5.0],
              server.debug_tensor_values["toggled_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              server.debug_tensor_values["toggled_2:0:DebugIdentity"])
        else:
          # Odd runs: the watches are expected to be toggled off.
          self.assertEqual(0, len(server.debug_tensor_values))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:46,代码来源:session_debug_grpc_test.py


注:本文中的tensorflow.python.debug.lib.debug_utils.watch_graph函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。