当前位置: 首页>>代码示例>>Python>>正文


Python context.get_default_context函数代码示例

本文整理汇总了Python中tensorflow.python.eager.context.get_default_context函数的典型用法代码示例。如果您正苦于以下问题:Python get_default_context函数的具体用法?Python get_default_context怎么用?Python get_default_context使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了get_default_context函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: decorator

  def decorator(self, **kwargs):
    """Runs the wrapped test and fails if Tensors created by it stay alive.

    The ids of every Tensor/Variable currently tracked by the garbage
    collector are recorded, the wrapped test `f` is run inside
    `IsolateTest()`, known caches are cleared, and any Tensor/Variable whose
    id was not recorded beforehand is reported.

    Args:
      **kwargs: forwarded unchanged to the wrapped test.

    Raises:
      AssertionError: if at least one new Tensor or Variable is still tracked
        by the garbage collector once the test has finished.
    """

    def _looks_like_tensor(candidate):
      # The isinstance check can raise ReferenceError when the object no
      # longer exists; such objects cannot be leaks.
      try:
        return isinstance(candidate, (ops.Tensor, variables.Variable))
      except ReferenceError:
        return False

    known_ids = {
        id(candidate)
        for candidate in gc.get_objects()
        if _looks_like_tensor(candidate)
    }
    inherited_prefix = ops.get_default_graph()._container_prefix
    with IsolateTest():
      # The test gets a fresh graph so collections are dropped afterwards;
      # the container prefix is carried over so that values of variables
      # leaked while executing eagerly can still be printed.
      ops.get_default_graph()._container_prefix = inherited_prefix
      f(self, **kwargs)

    # Best-effort cache clearing: entries left in these caches would
    # otherwise be reported as leaked Tensors.
    backprop._last_zero = [None]
    backprop._shape_dtype = [None, None]
    context.get_default_context().scalar_cache().clear()
    gc.collect()

    leaked = [
        obj for obj in gc.get_objects()
        if _looks_like_tensor(obj) and id(obj) not in known_ids
    ]
    if leaked:
      report = "%d Tensors not deallocated after test: %s" % (
          len(leaked),
          str(leaked),
      )
      raise AssertionError(report)
开发者ID:Lin-jipeng,项目名称:tensorflow,代码行数:34,代码来源:test_util.py

示例2: _compare

 def _compare(self, dims, val, np_ans, use_gpu):
   """Fills a tensor with `val` eagerly and checks it against `np_ans`.

   The op is placed on "GPU:0" only when `use_gpu` is set and the default
   context reports at least one GPU; otherwise it runs on "CPU:0".
   """
   default_ctx = context.get_default_context()
   if use_gpu and default_ctx.num_gpus():
     target_device = "GPU:0"
   else:
     target_device = "CPU:0"
   with ops.device(target_device):
     filled = array_ops.fill(dims, val, name="fill")
     actual = filled.numpy()
   self.assertAllClose(np_ans, actual)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:7,代码来源:constant_op_eager_test.py

示例3: execute

def execute(op_name, num_outputs, inputs, attrs=None, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch.
                 (Explicitly provided instead of being inferred for performance
                 reasons).
    inputs: A list of inputs to the operation. Each entry is unwrapped with
      `ag_core.getval` and must then expose a `_handle` attribute.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    name: Customized name for the operation. Only used in error messages, in
      trace records and as an argument to post-execution callbacks.

  Returns:
    A list with one Tensor object per output handle returned by the runtime.
    The list is returned as-is regardless of its length.

  Raises:
    An exception on error: a `core._NotOkStatusException` raised by the
    runtime is converted with `core._status_to_exception`, with
    " name: <name>" appended to the message when `name` is given.
  """
  ctx = context.get_default_context()
  # TODO(apassos) move this to convert_to_tensor
  inputs = [ag_core.getval(x) for x in inputs]
  # pylint: disable=protected-access
  input_handles = [c._handle for c in inputs]
  device_name = ctx.device_name
  try:
    outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                            str(op_name), input_handles, attrs,
                                            num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    raise core._status_to_exception(e.code, message)

  tensors = [tensor._tensor_from_handle(x) for x in outh]
  # pylint: enable=protected-access

  # TODO(alive, cais): Use the execution callback mechanism.
  # Read the active trace once so the None check and the recording below act
  # on the same object.
  trace = core.active_trace()
  if trace is not None:
    trace_name = name if name else op_name
    for t in tensors:
      trace.record_tensor(trace_name,
                          ops.tensor_id(t),
                          t._device_name(),  # pylint: disable=protected-access
                          t.shape.num_elements())

  # TODO(cais): Optimize this, perhaps by replacing this execute function with
  # a different one when there are execution callback(s).
  for callback in ctx.post_execution_callbacks:
    callback(op_name, name, attrs, inputs, tensors)

  return tensors
开发者ID:keveman,项目名称:tensorflow,代码行数:59,代码来源:execute.py

示例4: enable_tracing

def enable_tracing():
  """Starts tracing of execution and memory usage.

  Replaces the module-level `_active_trace` with a fresh
  `memory_trace.MemoryTrace` sized by the number of devices reported by the
  default eager context.

  WARNING: tracing is not thread-safe.
  """
  global _active_trace
  device_count = len(context.get_default_context().devices())
  _active_trace = memory_trace.MemoryTrace(device_count)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:8,代码来源:core.py

示例5: register_function_def

def register_function_def(fdef):
  """Adds the serialized form of `fdef` to the default eager context.

  Args:
    fdef: an object exposing `SerializeToString()`; its serialized bytes and
      their length are handed to `TFE_ContextAddFunctionDef`.
  """
  serialized = fdef.SerializeToString()
  with errors.raise_exception_on_not_ok_status() as status:
    ctx_handle = context.get_default_context()._handle  # pylint: disable=protected-access
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        ctx_handle, serialized, len(serialized), status)
开发者ID:chdinh,项目名称:tensorflow,代码行数:8,代码来源:function.py

示例6: add_execution_callback

def add_execution_callback(callback):
  """Registers `callback` as a post-execution callback on the default context.

  Execution callbacks run right after an eager operation or function
  finishes, and receive the op's type, name, attributes, inputs and outputs.
  Several callbacks may be registered; they are invoked in registration
  order. Use `clear_execution_callbacks()` to remove all of them.

  Calling this function also rebinds `execute.execute` to
  `execute.execute_with_callbacks`.

  Example:
  ```python
  def print_even_callback(op_type, op_name, attrs, inputs, outputs):
    # A callback that prints only the even output values.
    if outputs[0].numpy() % 2 == 0:
      print("Even output from %s: %s" % (op_name or op_type,  outputs))
  tfe.add_execution_callback(print_even_callback)

  x = tf.pow(2.0, 3.0) - 3.0
  y = tf.multiply(x, tf.add(1.0, 5.0))
  # When the line above is run, you will see all intermediate outputs that are
  # even numbers printed to the console.

  tfe.clear_execution_callbacks()
  ```

  Args:
    callback: a callable with the signature
      `f(op_type, op_name, attrs, inputs, outputs)`, where
      `op_type` is the type of the operation that just ran (e.g., `MatMul`);
      `op_name` is the client-assigned name of that operation, or `None` if
        no name was set;
      `attrs` is a `tuple` of alternating attribute names and values;
      `inputs` is the `list` of input `Tensor`(s) to the op;
      `outputs` is the `list` of output `Tensor`(s) from the op.
      Anything the callback returns is ignored.
  """
  # Switch to the callback-aware execute implementation before registering.
  execute.execute = execute.execute_with_callbacks
  default_ctx = context.get_default_context()
  default_ctx.add_post_execution_callback(callback)
开发者ID:Eagle732,项目名称:tensorflow,代码行数:42,代码来源:execution_callbacks.py

示例7: testDefaultContext

 def testDefaultContext(self):
   """Checks that `as_default()` scopes the default context and nests."""
   initial_ctx = context.get_default_context()
   # Repeated lookups return the same object.
   self.assertIs(context.get_default_context(), initial_ctx)

   # Constructing a Context does not make it the default.
   outer_ctx = context.Context()
   self.assertIs(context.get_default_context(), initial_ctx)

   # Neither does building the context manager without entering it.
   outer_scope = outer_ctx.as_default()
   self.assertIs(context.get_default_context(), initial_ctx)

   with outer_scope as outer_ctx:
     self.assertIs(context.get_default_context(), outer_ctx)
     with context.Context().as_default() as inner_ctx:
       self.assertIs(context.get_default_context(), inner_ctx)
     # Leaving the inner scope restores the outer context.
     self.assertIs(context.get_default_context(), outer_ctx)

   # Leaving the outer scope restores the initial default.
   self.assertIs(context.get_default_context(), initial_ctx)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:13,代码来源:core_test.py

示例8: run_fn

 def run_fn(ctx1):
   """Asserts that this thread's default context is distinct and pristine.

   Args:
     ctx1: a context obtained elsewhere; it must not be the object returned
       by `context.get_default_context()` here.
   """
   thread_ctx = context.get_default_context()
   # Default contexts created in different threads are different objects.
   self.assertIsNot(ctx1, thread_ctx)
   # The context created in this thread must carry the default settings.
   self.assertFalse(thread_ctx.in_graph_mode())
   self.assertTrue(thread_ctx.in_eager_mode())
   self.assertEqual('', thread_ctx.scope_name)
   self.assertEqual(-1, thread_ctx._device_index)  # pylint: disable=protected-access
   self.assertFalse(thread_ctx.recording_summaries)
   self.assertIsNone(thread_ctx.summary_writer_resource)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:12,代码来源:core_test.py

示例9: testVariableEager

  def testVariableEager(self):
    """Covers ResourceVariable properties, init and assignment in eager mode."""
    with context.eager_mode():
      ones = array_ops.ones(shape=[10, 20, 35], dtype=dtypes.int32)
      passthrough = lambda x: x
      with ops.name_scope("foo"):
        scoped_var = resource_variable_ops.ResourceVariable(
            name="var7",
            initial_value=ones,
            caching_device="cpu:0",
            constraint=passthrough)

      # Properties of the variable created inside the "foo" name scope.
      self.assertEqual(dtypes.int32, scoped_var.dtype)
      self.assertEqual("foo/var7:0", scoped_var.name)
      self.assertAllEqual([10, 20, 35], scoped_var.shape.as_list())
      self.assertEqual(
          context.get_default_context().device_name, scoped_var.device)
      self.assertTrue(isinstance(scoped_var.handle, ops.EagerTensor))
      self.assertEqual(passthrough, scoped_var.constraint)
      self.assertAllEqual(ones.numpy(), scoped_var.read_value().numpy())
      self.assertAllEqual(ones.numpy(), scoped_var.value().numpy())

      # A callable initial value is accepted as well.
      doubled_init = lambda: ones * 2
      other_var = resource_variable_ops.ResourceVariable(
          initial_value=doubled_init, name="var7")
      self.assertEqual("var7:0", other_var.name)
      self.assertAllEqual(2 * ones.numpy(), other_var.read_value().numpy())

      # assign_add: 2 * ones + ones.
      updated = other_var.assign_add(scoped_var.read_value())
      self.assertAllEqual(
          scoped_var.read_value().numpy() * 3, updated.numpy())

      # assign_sub: back down to 2 * ones.
      updated = other_var.assign_sub(scoped_var.read_value())
      self.assertAllEqual(
          scoped_var.read_value().numpy() * 2, updated.numpy())

      # assign: overwrite with the first variable's value.
      other_var.assign(scoped_var.read_value())
      self.assertAllEqual(
          scoped_var.read_value().numpy(), other_var.read_value().numpy())

      # load: overwrite with twice the first variable's value.
      other_var.load(2 * scoped_var.read_value())
      self.assertAllEqual(
          2 * scoped_var.read_value().numpy(), other_var.read_value().numpy())

      # convert_to_tensor yields the variable's current value.
      as_tensor = ops.convert_to_tensor(scoped_var)
      self.assertAllEqual(as_tensor.numpy(), scoped_var.read_value().numpy())

      # Arithmetic operators work directly on the variable.
      self.assertAllEqual(
          (scoped_var * 2).numpy(), (scoped_var + scoped_var).numpy())
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:49,代码来源:resource_variable_ops_test.py

示例10: add_execution_callback

def add_execution_callback(callback):
  """Registers `callback` as a post-execution callback on the default context.

  Execution callbacks run right after an eager operation or function
  finishes, and receive the op's type, name, attributes, inputs and outputs.
  Several callbacks may be registered; they are invoked in registration
  order.

  Args:
    callback: a callable with the signature
      `f(op_type, op_name, attrs, inputs, outputs)`, where
      `op_type` is the type of the operation that just ran (e.g., `MatMul`);
      `op_name` is the client-assigned name of that operation, or `None` if
        no name was set;
      `attrs` is a `tuple` of alternating attribute names and values;
      `inputs` is the `list` of input `Tensor`(s) to the op;
      `outputs` is the `list` of output `Tensor`(s) from the op.
      Anything the callback returns is ignored.
  """
  default_ctx = context.get_default_context()
  default_ctx.add_post_execution_callback(callback)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:24,代码来源:execution_callbacks.py

示例11: testEagerIdentity

  def testEagerIdentity(self):
    """Checks that eager `identity` keeps values and lands on the device."""
    with context.eager_mode():
      if not context.get_default_context().num_gpus():
        self.skipTest("No GPUs found")

      def _test(x, y, device):
        self.assertAllEqual(x.numpy(), y.numpy())
        self.assertTrue(device in y.device.lower())

      with ops.device("gpu:0"):
        current = constant_op.constant([[2], [3]], dtype=dtypes.float32)

      # Chain of copies: gpu -> gpu, gpu -> cpu, cpu -> cpu, cpu -> gpu.
      # Each step copies the previous result and checks value and placement.
      hops = (
          ("gpu:0", "gpu"),
          ("cpu:0", "cpu"),
          ("cpu:0", "cpu"),
          ("gpu:0", "gpu"),
      )
      for device_spec, expected_fragment in hops:
        with ops.device(device_spec):
          copied = array_ops.identity(current)
          _test(current, copied, expected_fragment)
        current = copied
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:24,代码来源:array_ops_test.py

示例12: _init_from_args


#.........这里部分代码省略.........
    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = context.in_graph_mode()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          if self._in_graph_mode:
            attr = attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(
                    s=[compat.as_bytes("loc:@%s" % handle_name)]))
            with ops.get_default_graph()._attr_scope({"_class": attr}):
              with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value(), name="initial_value", dtype=dtype)
              self._handle = _eager_safe_variable_handle(
                  shape=initial_value.get_shape(),
                  dtype=initial_value.dtype.base_dtype,
                  shared_name=handle_name,
                  name=name,
                  graph_mode=self._in_graph_mode)
              self._handle_device = (
                  self._handle.device if self._in_graph_mode else
                  context.get_default_context().device_name)
              self._shape = initial_value.get_shape()
          else:
            initial_value = initial_value()
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=handle_name,
                name=name,
                graph_mode=False)
            self._handle_device = (
                self._handle.device if self._in_graph_mode else
                context.get_default_context().device_name)
            self._shape = initial_value.get_shape()
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
开发者ID:keithc61,项目名称:tensorflow,代码行数:67,代码来源:resource_variable_ops.py

示例13: seterr

def seterr(inf_or_nan=None):
  """Configures how the default eager context reacts to `inf`/`nan` values.

  Example:
  ```python
  tfe.seterr(inf_or_nan="raise")
  a = tf.constant(10.0)
  b = tf.constant(0.0)
  try:
    c = a / b  # <-- Raises InfOrNanError.
  except Exception as e:
    print("Caught Exception: %s" % e)

  tfe.seterr(inf_or_nan="ignore")
  c = a / b  # <-- Does NOT raise exception anymore.
  ```

  Args:
    inf_or_nan: Action for infinity (`inf`) and NaN (`nan`) values.
      Possible values: `{"ignore", "print", "raise", "warn"}`.
      `"ignore"`: take no action when `inf` values appear.
      `"print"`: print a warning to `stdout`.
      `"raise"`: raise an `InfOrNanError`.
      `"warn"`: print a warning using `tf.logging.warn`.
      `None` leaves the currently configured action untouched.

  Returns:
    A dictionary of old actions, keyed by `"inf_or_nan"`.

  Raises:
    ValueError: If the value of any keyword arguments is invalid.
  """
  if inf_or_nan not in _VALID_CALLBACK_ACTIONS:
    raise ValueError(
        "Invalid action value for inf_or_nan: %s. "
        "Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))

  previous = {"inf_or_nan": "ignore"}
  ctx = context.get_default_context()
  change_requested = inf_or_nan is not None

  # Split the registered callbacks into the inf/nan checker (bare or wrapped
  # in a functools.partial), whose action is recorded as the old setting, and
  # every other callback, which must survive a reconfiguration.
  kept_callbacks = []
  for registered in ctx.post_execution_callbacks:
    if registered == inf_nan_callback:
      previous["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
    elif (isinstance(registered, functools.partial) and
          registered.func == inf_nan_callback):
      previous["inf_or_nan"] = registered.keywords.get(
          "action", _DEFAULT_CALLBACK_ACTION)
    elif change_requested:
      kept_callbacks.append(registered)

  if not change_requested:
    return previous

  # Rebuild the callback list: unrelated callbacks first, then the freshly
  # configured inf/nan checker unless checking is being turned off.
  ctx.clear_post_execution_callbacks()
  for registered in kept_callbacks:
    ctx.add_post_execution_callback(registered)
  if inf_or_nan != "ignore":
    checker = functools.partial(inf_nan_callback, action=inf_or_nan)
    ctx.add_post_execution_callback(checker)

  return previous
开发者ID:allanbian1017,项目名称:tensorflow,代码行数:64,代码来源:execution_callbacks.py

示例14: clear_execution_callbacks

def clear_execution_callbacks():
  """Removes every post-execution callback from the default eager context."""
  default_ctx = context.get_default_context()
  default_ctx.clear_post_execution_callbacks()
开发者ID:allanbian1017,项目名称:tensorflow,代码行数:3,代码来源:execution_callbacks.py

示例15: inf_nan_callback

def inf_nan_callback(op_type,
                     op_name,
                     attrs,
                     inputs,
                     outputs,
                     check_inf=True,
                     check_nan=True,
                     action=_DEFAULT_CALLBACK_ACTION):
  """An execution callback that checks for `inf`s and `nan`s in output tensors.

  This callback can be used with `tfe.add_execution_callback` to check for
  invalid numeric values. E.g.,
  ```python
  tfe.add_execution_callback(tfe.inf_nan_callback)
  ```

  Args:
    op_type: Name of the TFE operation type (e.g., `MatMul`).
    op_name: Name of the TFE operation. This name is set by client and can be
      `None` if it unset.
    attrs: Attributes of the TFE operation, as a tuple of alternating attribute
      names and attribute values. Currently unused by this callback.
    inputs: The `list` of input tensors to the operation, currently unused by
      this callback.
    outputs: The `list` of output tensors from the operation, checked by this
      callback for `inf` and `nan` values. Outputs whose dtype is not numpy
      compatible, or is not a floating, complex or integer type, are skipped.
    check_inf: (`bool`) Whether this callback should check for `inf` values in
      the output tensor values.
    check_nan: (`bool`) Whether this callback should check for `nan` values in
      the output tensor values.
    action: (`str`) Action to be taken by the callback when `inf` or `nan`
      values are detected. Possible values {"raise", "warn", "print"}
      `"raise"`: Raise a `InfOrNanError`.
      `"warn"`: Log a warning using `tf.logging.warn`.
      `"print"`: Print a message to `sys.stdout`.

  Raises:
    InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
      `action` is `"raise"`.
    ValueError: iff the value of `action` is invalid. Note that `action` is
      only inspected once an `inf` or `nan` value has actually been detected.
  """
  del attrs, inputs  # Not used.

  ctx = context.get_default_context()
  # Use NumPy's abstract scalar types here. The aliases `np.float` and
  # `np.complex` were plain Python builtins and no longer exist in current
  # NumPy releases.
  checked_kinds = (np.floating, np.complexfloating, np.integer)

  for index, output in enumerate(outputs):
    if not output.dtype.is_numpy_compatible:
      continue

    numpy_dtype = output.dtype.as_numpy_dtype
    if not any(np.issubdtype(numpy_dtype, kind) for kind in checked_kinds):
      continue

    try:
      # The "T" attr must describe the tensor handed to CheckNumerics, i.e.
      # this output, not the first output of the op.
      check_numerics_op_attrs = (
          "message", "Eager-mode inf/nan check",
          "T", output.dtype.as_datatype_enum)
      # TODO(cais): Consider moving this into execute.py.
      # pylint: disable=protected-access
      pywrap_tensorflow.TFE_Py_Execute(
          ctx._handle, output.device, "CheckNumerics", [output._handle],
          check_numerics_op_attrs, 1)
      # pylint: enable=protected-access
    except core._NotOkStatusException:  # pylint: disable=protected-access
      # The op failed; inspect the values on the host to find out whether the
      # failure is one of the conditions the caller asked to be checked.
      value = output.numpy()
      inf_detected = np.any(np.isinf(value)) and check_inf
      nan_detected = np.any(np.isnan(value)) and check_nan
      if not inf_detected and not nan_detected:
        continue

      error = InfOrNanError(op_type, op_name, index, len(outputs), value)
      if action == "print":
        print("Warning: %s" % str(error))
      elif action == "warn":
        logging.warn(str(error))
      elif action == "raise":
        raise error
      else:
        raise ValueError(
            "Invalid action for inf_nan_callback: %s. Valid actions are: "
            "{print | warn | raise}" % action)
开发者ID:allanbian1017,项目名称:tensorflow,代码行数:81,代码来源:execution_callbacks.py


注:本文中的tensorflow.python.eager.context.get_default_context函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。