当前位置: 首页>>代码示例>>Python>>正文


Python tape.stop_recording函数代码示例

本文整理汇总了Python中tensorflow.python.eager.tape.stop_recording函数的典型用法代码示例。如果您正苦于以下问题:Python stop_recording函数的具体用法?Python stop_recording怎么用?Python stop_recording使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了stop_recording函数的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _create_variable

  def _create_variable(self, next_creator, *args, **kwargs):
    """Builds one variable per device and wraps them in a distributed variable.

    See `DistributionStrategy.scope`.

    Args:
      next_creator: Callable invoked as `next_creator(*args, **kwargs)` once
        per device to build that device's component variable.
      *args: Positional arguments forwarded to `next_creator`.
      **kwargs: Keyword arguments forwarded to `next_creator`. "collections",
        "colocate_with" and "tower_local_reduce_method" are consumed here;
        "name" and "initial_value" are overwritten for every device after the
        first one.

    Returns:
      A `values.TowerLocalVariable` when "tower_local_reduce_method" was
      supplied, otherwise a `values.MirroredVariable`.
    """
    # The wrapper, not its per-device components, is what ends up in the
    # requested collections, so components are created with no collections.
    target_collections = kwargs.pop("collections", None)
    if target_collections is None:
      target_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    devices = self._get_devices_from(kwargs.pop("colocate_with", None))

    reduce_method = kwargs.pop("tower_local_reduce_method", None)
    if reduce_method is not None:
      kwargs["trainable"] = False

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
      per_device = {}
      for replica_id, device in enumerate(devices):
        with ops.device(device):
          if replica_id > 0:
            first = per_device[devices[0]]
            # Replicas are named after the first device's variable.
            kwargs["name"] = "%s/replica_%d" % (
                first.name.split(":")[0], replica_id)
            # Replicas start from the same value as the first variable.
            seed = (first.value() if context.executing_eagerly()
                    else first.initial_value)
            kwargs["initial_value"] = array_ops.identity(seed)
          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            component = next_creator(*args, **kwargs)
          assert not isinstance(component, values.DistributedVariable)
          per_device[device] = component

      primary = per_device[devices[0]]
      if reduce_method is not None:
        result = values.TowerLocalVariable(per_device, primary, reduce_method)
      else:
        result = values.MirroredVariable(per_device, primary)

    if context.executing_eagerly():
      return result

    graph = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      target_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      trainable_ref = graph.get_collection_ref(
          ops.GraphKeys.TRAINABLE_VARIABLES)
      for component in per_device.values():
        trainable_ref.remove(component)
    graph.add_to_collections(target_collections, result)
    return result
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:58,代码来源:mirrored_strategy.py

示例2: _create_tpu_mirrored_variable

def _create_tpu_mirrored_variable(
    strategy, device_map, logical_device, real_mirrored_creator,
    *args, **kwargs):
  """Creates the per-device variables and wraps them as a TPUMirroredVariable.

  Args:
    strategy: Passed through to `values.TPUMirroredVariable`.
    device_map: Object whose `logical_to_actual_devices` resolves
      `logical_device` to the devices to create variables on; also passed to
      `values.TPUMirroredVariable`.
    logical_device: Logical device handed to `device_map` and to the wrapper.
    real_mirrored_creator: Callable invoked as
      `real_mirrored_creator(devices, *args, **kwargs)`; its return value is
      used as the list of component variables.
    *args: Positional arguments forwarded to `real_mirrored_creator`.
    **kwargs: Keyword arguments forwarded to `real_mirrored_creator` after
      "collections", "aggregation" and "caching_device" have been removed.

  Returns:
    The `values.TPUMirroredVariable` wrapping the created variables.

  Raises:
    ValueError: If "aggregation" is not one of the supported modes.
  """
  # The wrapper, not its components, is added to the requested collections,
  # so the components themselves are created with an empty collection list.
  requested_collections = kwargs.pop("collections", None)
  if requested_collections is None:
    requested_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
  # synchronization settings?

  # TODO(jhseu): Support aggregation in a replica context.
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  supported_modes = (
      vs.VariableAggregation.NONE,
      vs.VariableAggregation.SUM,
      vs.VariableAggregation.MEAN,
      vs.VariableAggregation.ONLY_FIRST_REPLICA,
  )
  if aggregation not in supported_modes:
    raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                     .format(aggregation, kwargs["name"]))

  # A user-specified caching device is not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    actual_devices = device_map.logical_to_actual_devices(logical_device)
    components = real_mirrored_creator(actual_devices, *args, **kwargs)
    result = values.TPUMirroredVariable(
        strategy, device_map, components, aggregation,
        logical_device=logical_device)

  # Graph collections are only maintained outside eager mode and functions.
  if context.executing_eagerly() or ops.inside_function():
    return result

  graph = ops.get_default_graph()
  # If "trainable" is True, next_creator() will add the member variables
  # to the TRAINABLE_VARIABLES collection, so we manually remove
  # them and replace with the MirroredVariable. We can't set
  # "trainable" to False for next_creator() since that causes functions
  # like implicit_gradients to skip those variables.
  if kwargs.get("trainable", True):
    requested_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
    trainable_ref = graph.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
    for component in components:
      trainable_ref.remove(component)
  graph.add_to_collections(requested_collections, result)
  return result
开发者ID:jackd,项目名称:tensorflow,代码行数:52,代码来源:tpu_strategy.py

示例3: compute_gradients

def compute_gradients(model, images, labels, num_replicas=1):
  """Returns gradients of the softmax cross-entropy loss w.r.t. `model.variables`.

  Args:
    model: Callable invoked as `model(images, training=True)` to produce
      logits; must expose a `variables` attribute.
    images: Input passed to `model`.
    labels: One-hot labels passed to the loss.
    num_replicas: The loss is divided by this value when it is not 1.

  Returns:
    The result of `GradientTape.gradient(loss, model.variables)`.
  """
  with tf.GradientTape() as recorder:
    predictions = model(images, training=True)
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=labels, logits=predictions)
    tf.contrib.summary.scalar(name='loss', tensor=loss)
    if num_replicas != 1:
      loss /= num_replicas

  # TODO(b/110991947): We can mistakenly trace the gradient call in
  # multi-threaded environment. Explicitly disable recording until
  # this is fixed.
  with tape.stop_recording():
    return recorder.gradient(loss, model.variables)
开发者ID:zhuochenKIDD,项目名称:tensorflow,代码行数:15,代码来源:resnet50_test.py

示例4: decorated

  def decorated(*args, **kwargs):
    """Calls `f` and registers the gradient function it returns.

    `f` is expected to return a `(result, grad_fn)` pair.

    In graph mode the inputs are converted to tensors, and the flattened
    results plus the inputs are routed through an `IdentityN` op whose
    gradient is overridden with `grad_fn`. Otherwise `f` runs with tape
    recording stopped and a single operation carrying `grad_fn` is recorded
    on the tape.

    Args:
      *args: Positional arguments forwarded to `f`.
      **kwargs: Keyword arguments forwarded to `f`; rejected in graph mode.

    Returns:
      The `result` part of what `f` returned. In graph mode it is rebuilt
      from the outputs of the `IdentityN` op.

    Raises:
      ValueError: If keyword arguments are passed in graph mode.
    """
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    # Only positional arguments that are tensors are recorded as inputs.
    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    # Run `f` unrecorded; the single operation recorded below stands in for
    # whatever `f` did.
    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    return result
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:47,代码来源:custom_gradient.py

示例5: decorated

  def decorated(*args, **kwargs):
    """Calls `f` unrecorded and registers the gradient function it returns.

    `f` is expected to return a `(result, grad_fn)` pair. `f` runs with tape
    recording stopped, then one operation is recorded on the tape with the
    flattened results as outputs, the tensor-valued positional arguments as
    inputs, and a wrapper around `grad_fn` as its gradient function.

    Args:
      *args: Positional arguments forwarded to `f`.
      **kwargs: Keyword arguments forwarded to `f`.

    Returns:
      The `result` part of what `f` returned, unchanged.
    """
    # Only positional arguments that are tensors are recorded as inputs.
    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    return result
开发者ID:1000sprites,项目名称:tensorflow,代码行数:22,代码来源:custom_gradient.py

示例6: _create_mirrored_variable

def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
  """Creates per-device variables and wraps them as one distributed variable.

  Args:
    devices: Sequence of devices; `devices[0]` identifies the primary
      component variable.
    real_mirrored_creator: Callable invoked as
      `real_mirrored_creator(devices, *args, **kwargs)`; its return value is
      indexed by device and must support `.values()`.
    *args: Positional arguments forwarded to `real_mirrored_creator`.
    **kwargs: Keyword arguments forwarded to `real_mirrored_creator` after
      "collections", "aggregation" and "caching_device" have been removed.
      "synchronization" is read but left in place.

  Returns:
    A `values.TowerLocalVariable` when synchronization is `ON_READ`,
    otherwise a `values.MirroredVariable`.

  Raises:
    ValueError: If the synchronization mode is `NONE` or unrecognized, or if
      the aggregation mode is not one of the supported values.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # Error messages are built with `format` rather than `+`: the mode values
  # are normally enum members and the variable name may be None, and
  # concatenating either onto a str raises TypeError instead of the intended
  # ValueError.
  synchronization = kwargs.get("synchronization",
                               variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError("`NONE` variable synchronization mode is not "
                     "supported with `Mirrored` distribution strategy. Please"
                     " change the `synchronization` for variable: {}"
                     .format(kwargs.get("name")))
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are tower local.
    is_tower_local = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_tower_local = False
  else:
    raise ValueError(
        "Invalid variable synchronization mode: {} for variable: {}"
        .format(synchronization, kwargs.get("name")))

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in (
      variable_scope.VariableAggregation.NONE,
      variable_scope.VariableAggregation.SUM,
      variable_scope.VariableAggregation.MEAN,
      variable_scope.VariableAggregation.ONLY_FIRST_TOWER
  ):
    raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                     .format(aggregation, kwargs.get("name")))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = real_mirrored_creator(devices, *args, **kwargs)

    if is_tower_local:
      result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
    else:
      result = values.MirroredVariable(index, index[devices[0]], aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        l.remove(v)
    g.add_to_collections(collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
开发者ID:Jordan1237,项目名称:tensorflow,代码行数:74,代码来源:mirrored_strategy.py

示例7: __call__

  def __call__(self, *args, **kwds):
    """Calls the graph function.

    The path taken depends on state left by earlier calls:

    * If `self._created_variables` is non-empty, `self._stateless_fn` is run.
    * Otherwise, if `self._stateful_fn` exists, it is run and an error is
      raised should that run create variables.
    * Otherwise this is treated as the first call: `self._initialize` runs,
      and the result comes from the stateless function, the concrete
      stateful function, or a conditional that picks between the two.

    Args:
      *args: Positional arguments for the wrapped function.
      **kwds: Keyword arguments for the wrapped function.

    Returns:
      Whatever the selected traced function returns.

    Raises:
      ValueError: If variables are created on a non-first call, or if a
        variable created inside the function is no longer available when the
        initialization condition is built.
    """
    if self._created_variables:
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    # This is the first call of __call__, so we have to initialize.
    self._initialize(args, kwds)
    if self._lifted_all_initializers and self._lifted_placeholders:
      with ops.init_scope():
        # Split the (handle, placeholder) pairs into two parallel tuples.
        handles, placeholders = zip(*self._lifted_placeholders)
        if context.executing_eagerly():
          lifted_fn = function_lib._EagerDefinedFunction(  # pylint: disable=protected-access
              "initializer" + str(ops.uid()),
              self._lifted_initializer_graph,
              placeholders, [], {})
          # The lifted initializer is invoked with tape recording stopped.
          with tape.stop_recording():
            lifted_fn.call(context.context(), list(handles))
      return self._stateless_fn(*args, **kwds)
    canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)

    if not self._created_variables:
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access

    def fn_with_cond(*inner_args, **inner_kwds):
      """Conditionally runs initialization if it's needed."""
      # Starts as True and is AND-ed with an is-initialized check for every
      # created variable.
      condition = True
      for wr in self._created_variables:
        # Calling `wr` yields the variable, or None once it is gone
        # (presumably `wr` is a weak reference -- TODO confirm).
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  v = tf.Variable(1.0)\n"
              "  return v\n"
              "\n"
              "v = f()  # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that @tf.function annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  return v\n"
              "\n"
              "f()  # <tf.Tensor: ... numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f()  # <tf.Tensor: ... numpy=2.>")
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                            inner_args, inner_kwds))

    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
开发者ID:gautam1858,项目名称:tensorflow,代码行数:86,代码来源:def_function.py

示例8: _real_mirrored_creator

    def _real_mirrored_creator(devices, *args, **kwargs):
      """Creates one MirroredVariable on the current worker.

      Builds one variable per entry in `devices` by calling `next_creator`
      with an overridden "initial_value" callable. With more than one worker,
      the first device's initial value is broadcast from the chief and
      received by the other workers; every later device copies the value of
      the first variable created here.

      Args:
        devices: Devices to create component variables on, in order.
        *args: Positional arguments forwarded to `next_creator`.
        **kwargs: Keyword arguments forwarded to `next_creator`. Must contain
          "name" and "initial_value"; "initial_value" is replaced for every
          device and "name" is replaced for every device after the first.

      Returns:
        A list with the created variable for each device, in device order.

      Raises:
        ValueError: If "initial_value" is not present in `kwargs`.
      """
      unique_var_name = ops.get_default_graph().unique_name(
          kwargs["name"], mark_as_used=False).rstrip("/")
      # pylint: disable=protected-access
      collective_instance_key = self._collective_keys.get_instance_key(
          key_id=unique_var_name)
      # Only the first device participates in the broadcast of initial values.
      group_key = self._collective_keys.get_group_key([devices[0]])
      group_size = self._num_workers
      if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
      initial_value = kwargs["initial_value"]
      # Normalize to a zero-argument callable so both forms are handled alike.
      if callable(initial_value):
        initial_value_fn = initial_value
      else:
        initial_value_fn = lambda: initial_value

      value_list = []
      for i, d in enumerate(devices):
        with ops.init_scope(), ops.device(d):
          if i == 0:
            # The initial value fn makes sure variables all initialized to
            # same values. The first device of the chief worker will send their
            # variable values to other workers.
            # `d` and `i` are bound as defaults so the closure keeps this
            # iteration's values.
            def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
              with ops.device(device):
                initial_value = initial_value_fn()
                assert not callable(initial_value)
                initial_value = ops.convert_to_tensor(initial_value)

                assert index == 0, index
                if self._num_workers > 1:
                  if self._is_chief:
                    bcast_send = collective_ops.broadcast_send(
                        initial_value, initial_value.shape, initial_value.dtype,
                        group_size, group_key, collective_instance_key)
                    # The returned value depends on the send op.
                    with ops.control_dependencies([bcast_send]):
                      return array_ops.identity(initial_value)
                  else:
                    return collective_ops.broadcast_recv(
                        initial_value.shape, initial_value.dtype, group_size,
                        group_key, collective_instance_key)
                return initial_value
          else:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)

            # Variables on non-first replica get initial values from the
            # variables created on the first device of each worker.
            def _overridden_initial_value_fn(device=d, index=i):
              assert index > 0
              with ops.device(device):
                if context.executing_eagerly():
                  return array_ops.identity(value_list[0].value())
                else:
                  return array_ops.identity(value_list[0].initial_value)

          kwargs["initial_value"] = _overridden_initial_value_fn
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(*args, **kwargs)

          if i == 0:
            # The name predicted before creation must match the name the
            # first variable actually received.
            actual_var_name = v.name.split(":")[0]
            assert unique_var_name == actual_var_name, "%r vs %r" % (
                unique_var_name, actual_var_name)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list
开发者ID:aritratony,项目名称:tensorflow,代码行数:76,代码来源:collective_all_reduce_strategy.py

示例9: _create_variable

  def _create_variable(self, next_creator, *args, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`.

    Args:
      next_creator: Callable invoked as `next_creator(*args, **kwargs)` once
        per device to build that device's component variable.
      *args: Positional arguments forwarded to `next_creator`.
      **kwargs: Keyword arguments forwarded to `next_creator`. "collections",
        "colocate_with", "aggregation" and "caching_device" are consumed
        here; "synchronization" is read but left in place; "name" and
        "initial_value" are overwritten for every device after the first.

    Returns:
      A `values.TowerLocalVariable` when synchronization is `ON_READ`,
      otherwise a `values.MirroredVariable`.

    Raises:
      ValueError: If the synchronization mode is `NONE` or unrecognized, or
        if the aggregation mode is not one of the supported values.
    """
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)

    # Error messages are built with `format` rather than `+`: the mode values
    # are normally enum members and the variable name may be None, and
    # concatenating either onto a str raises TypeError instead of the
    # intended ValueError.
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
      raise ValueError("`NONE` variable synchronization mode is not "
                       "supported with `Mirrored` distribution strategy. Please"
                       " change the `synchronization` for variable: {}"
                       .format(kwargs.get("name")))
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
      # Variables that are to be synced on read are tower local.
      is_tower_local = True
      kwargs["trainable"] = False
    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
          synchronization == variable_scope.VariableSynchronization.AUTO):
      # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
      is_tower_local = False
    else:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}"
          .format(synchronization, kwargs.get("name")))

    # Get aggregation value
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in [
        variable_scope.VariableAggregation.NONE,
        variable_scope.VariableAggregation.SUM,
        variable_scope.VariableAggregation.MEAN
    ]:
      raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                       .format(aggregation, kwargs.get("name")))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
      index = {}
      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = index[devices[0]].name.split(":")[0]
            # We append a / to variable names created on towers with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
            # Initialize replicas with the same value:
            if context.executing_eagerly():
              kwargs["initial_value"] = array_ops.identity(
                  index[devices[0]].value())
            else:
              # `d` is bound as a default so the callable keeps this
              # iteration's device when it is invoked later.
              def initial_value_fn(device=d):
                with ops.device(device):
                  return array_ops.identity(index[devices[0]].initial_value)
              kwargs["initial_value"] = initial_value_fn
          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)
          assert not isinstance(v, values.DistributedVariable)
          index[d] = v

      if is_tower_local:
        result = values.TowerLocalVariable(index, index[devices[0]],
                                           aggregation)
      else:
        result = values.MirroredVariable(index, index[devices[0]], aggregation)

    if not context.executing_eagerly():
      g = ops.get_default_graph()
      # If "trainable" is True, next_creator() will add the member variables
      # to the TRAINABLE_VARIABLES collection, so we manually remove
      # them and replace with the MirroredVariable. We can't set
      # "trainable" to False for next_creator() since that causes functions
      # like implicit_gradients to skip those variables.
      if kwargs.get("trainable", True):
        collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
        l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
        for v in index.values():
          l.remove(v)
      g.add_to_collections(collections, result)
    return result
开发者ID:dan-lennox,项目名称:tensorflow,代码行数:94,代码来源:mirrored_strategy.py


注:本文中的tensorflow.python.eager.tape.stop_recording函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。