

Python graph_editor.reroute_ts Function Code Examples

This article collects typical usage examples of the Python function tensorflow.contrib.graph_editor.reroute_ts. If you are wondering what reroute_ts does, how to call it, or what it looks like in real projects, the curated examples below should help.


11 code examples of the reroute_ts function are shown below, sorted by popularity by default.
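
Before diving into the examples, here is a minimal sketch of what reroute_ts does (TF 1.x; all names are illustrative, mirroring the fixture that Example 1 below appears to use): consumers of the tensors in the second list are rewired to read the matching tensors in the first list.

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

a0 = tf.constant(1.0, name='a0')
b0 = tf.constant(2.0, name='b0')
c0 = tf.add(a0, b0, name='c0')
a1 = tf.constant(3.0, name='a1')
b1 = tf.constant(4.0, name='b1')
c1 = tf.add(a1, b1, name='c1')

# Rewire every consumer of a1/b1 (here only c1) to read a0/b0 instead;
# c0 is untouched, so both adds now compute a0 + b0.
ge.reroute_ts([a0, b0], [a1, b1])

with tf.Session() as sess:
    print(sess.run([c0, c1]))  # [3.0, 3.0]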

Example 1: test_reroute

  def test_reroute(self):
    ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(match.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
    self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))

    ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
    self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
    self.assertTrue(match.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))
Developer: 1000sprites, Project: tensorflow, Lines: 8, Source: reroute_test.py

Example 2: _FoldUnfusedBatchNorms

def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):
    has_scaling = _HasScaling(graph, input_to_ops_map, bn)

    if not _IsValidUnfusedBatchNorm(graph, bn):
      continue

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(
        graph,
        bn,
        has_scaling=has_scaling,
        freeze_batch_norm_delay=freeze_batch_norm_delay,
        is_training=is_training)

    activation = common.GetEndpointActivationOp(graph, bn)
    if activation:
      nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                     [original_op.outputs[0]],
                                                     can_modify=[activation])
      if nodes_modified_count != 1:
        raise ValueError('Unexpected inputs to op: %s' % activation.name)
      continue

    # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above.
    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
    nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                   [original_op.outputs[0]],
                                                   can_modify=[add_bypass])
    if nodes_modified_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
Developer: BhaskarNallani, Project: tensorflow, Lines: 49, Source: fold_batch_norms.py
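
Both reroute_ts calls above pass can_modify, which restricts rewiring to the listed consumer ops; the returned count of modified ops is then checked so that an unexpected graph shape fails loudly. A minimal sketch of that behavior, with illustrative names rather than the folding fixture:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

x = tf.constant(1.0, name='x')
old = tf.identity(x, name='old')
new = tf.identity(x, name='new')
keep = tf.square(old, name='keep')    # consumer left untouched
target = tf.sqrt(old, name='target')  # the one consumer we rewire

# Only `target` may be modified; reroute_ts returns how many consumer
# ops were actually rewired, which the folding code validates.
count = ge.reroute_ts([new], [old], can_modify=[target.op])
assert count == 1  # `keep` still reads `old`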

Example 3: _FoldFusedBatchNorms

def _FoldFusedBatchNorms(graph):
  """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
  for match in _FindFusedBatchNorms(graph):
    scope, sep, _ = match.layer_op.name.rpartition('/')
    # Make sure new ops are added to `graph` and put on the same device as
    # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
    # named `scope`. Otherwise, TF creates a unique scope whose name starts with
    # `scope`.
    with graph.as_default(), graph.name_scope(scope + sep), ops.device(
        match.bn_op.device):
      with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
        # new weights = old weights * gamma / sqrt(variance + epsilon)
        # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
        multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
            match.variance_tensor + match.bn_op.get_attr('epsilon'))
        bias_tensor = math_ops.subtract(
            match.beta_tensor,
            match.mean_tensor * multiplier_tensor,
            name='bias')

        # The shape of depthwise weights is different, so we need to reshape the
        # multiplier_tensor to ensure that the scaled_weight_tensor has the
        # expected shape.
        if match.layer_op.type == 'DepthwiseConv2dNative':
          new_shape = [
              match.weight_tensor.get_shape().as_list()[2],
              match.weight_tensor.get_shape().as_list()[3]
          ]
          multiplier_tensor = array_ops.reshape(
              multiplier_tensor, new_shape, name='scale_reshape')

      # TODO(suharshs): This naming of the following ops needs to carefully
      # follow the naming expected by quantize.py. Generalize the quantize code
      # to not require these delicate naming conventions.
      scaled_weight_tensor = math_ops.multiply(
          match.weight_tensor, multiplier_tensor, name='mul_fold')

      new_layer_tensor = _CloneWithNewOperands(
          match.layer_op, match.input_tensor, scaled_weight_tensor)

      bias_add_tensor = math_ops.add(
          new_layer_tensor, bias_tensor, name='add_fold')

      nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
                                                     match.output_tensor)
      if nodes_modified_count != 1:
        raise ValueError(
            'Unexpected inputs to op: %s' % match.output_tensor.name)
Developer: AbhinavJain13, Project: tensorflow, Lines: 58, Source: fold_batch_norms.py
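
The fold rests on the identity stated in the comments: because the layer is linear in its weights, gamma * (conv(x, W) - mean) / sqrt(variance + epsilon) + beta equals conv(x, W * multiplier) + (beta - mean * multiplier) with multiplier = gamma / sqrt(variance + epsilon). A quick NumPy check of that identity with toy numbers (a dot product standing in for the conv):

import numpy as np

np.random.seed(0)
x, w = np.random.randn(8), np.random.randn(8)  # dot product stands in for conv
gamma, beta, mean, var, eps = 1.5, 0.2, 0.3, 2.0, 1e-3

conv = np.dot(x, w)
bn_out = gamma * (conv - mean) / np.sqrt(var + eps) + beta

multiplier = gamma / np.sqrt(var + eps)                          # multiplier_tensor
folded = np.dot(x, w * multiplier) + (beta - mean * multiplier)  # mul_fold + add_fold

assert np.allclose(bn_out, folded)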

Example 4: FoldBatchNorms

def FoldBatchNorms(graph):
  """Finds batch norm layers in the graph, folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
  # Fail immediately when the graph contains unsupported fused batch norm ops.
  if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
    raise ValueError('Fused batch norm is not supported')

  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):
    has_scaling = _HasScaling(graph, input_to_ops_map, bn)

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling)

    activation = common.GetEndpointActivationOp(graph, bn)
    if activation:
      nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                     [original_op.outputs[0]],
                                                     can_modify=[activation])
      if nodes_modified_count != 1:
        raise ValueError('Unexpected inputs to op: %s' % activation.name)
      continue

    # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above.
    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
    nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                   [original_op.outputs[0]],
                                                   can_modify=[add_bypass])
    if nodes_modified_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
Developer: alexsax, Project: tensorflow, Lines: 42, Source: fold_batch_norms.py

Example 5: _InsertQuantOp


#......... some code omitted here .........
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  fake_quant_ops = set([
      'FakeQuantWithMinMaxVars',
      'FakeQuantWithMinMaxArgs'
  ])
  if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
    return

  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
Developer: Jackiefan, Project: tensorflow, Lines: 101, Source: quantize.py
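
The pattern shared by this and the later quantization examples: build a quantized copy of the producer's output, then reroute only the intended consumers onto it. A minimal sketch using a stock FakeQuant op as a stand-in for quant_ops.MovingAvgQuantize (illustrative names; not the contrib implementation):

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

x = tf.placeholder(tf.float32, [None, 4], name='x')
act = tf.nn.relu(x, name='act')     # producer
out = tf.identity(act, name='out')  # consumer whose input we quantize

quant = tf.fake_quant_with_min_max_args(act, min=-6.0, max=6.0, name='act_quant')

# Rewire only `out`; without can_modify the FakeQuant op itself would
# also be rerouted to consume its own output, creating a cycle.
count = ge.reroute_ts([quant], [act], can_modify=[out.op])
assert count == 1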

Example 6: _FoldFusedBatchNorms

def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  for match in _FindFusedBatchNorms(graph):
    scope, sep, _ = match.layer_op.name.rpartition('/')
    # Make sure new ops are added to `graph` and put on the same device as
    # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
    # named `scope`. Otherwise, TF creates a unique scope whose name starts with
    # `scope`.
    with graph.as_default(), graph.name_scope(scope + sep):
      with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
        # new weights = old weights * gamma / sqrt(variance + epsilon)
        # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
        multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
            match.variance_tensor + match.bn_op.get_attr('epsilon'))
        bias_tensor = math_ops.subtract(
            match.beta_tensor,
            match.mean_tensor * multiplier_tensor,
            name='bias')

        correction_scale, correction_recip, correction_offset = None, None, None
        if is_training:
          correction_scale, correction_recip, correction_offset = (
              _ComputeBatchNormCorrections(
                  context='',
                  match=match,
                  freeze_batch_norm_delay=freeze_batch_norm_delay,
                  fused_batch_norm=True))
        # The shape of depthwise weights is different, so we need to reshape the
        # multiplier_tensor to ensure that the scaled_weight_tensor has the
        # expected shape.
        weights = match.weight_tensor
        if match.layer_op.type == 'DepthwiseConv2dNative':
          new_shape = [
              match.weight_tensor.get_shape().as_list()[2],
              match.weight_tensor.get_shape().as_list()[3]
          ]
          multiplier_tensor = array_ops.reshape(
              multiplier_tensor, new_shape, name='scale_reshape')

          if correction_scale is not None:
            correction_scale = array_ops.reshape(
                correction_scale, new_shape, name='correction_reshape')

      if correction_scale is not None:
        weights = math_ops.multiply(
            correction_scale, weights, name='correction_mult')

      scaled_weight_tensor = math_ops.multiply(
          weights, multiplier_tensor, name='mul_fold')
      new_layer_tensor = _CloneWithNewOperands(
          match.layer_op, match.input_tensor, scaled_weight_tensor)

      if correction_recip is not None:
        new_layer_tensor = math_ops.multiply(
            correction_recip, new_layer_tensor, name='post_conv_mul')
        new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset),
                                        'correction_add')

      bias_add_tensor = math_ops.add(
          new_layer_tensor, bias_tensor, name='add_fold')

      nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
                                                     match.output_tensor)
      if nodes_modified_count == 0:
        raise ValueError('Folding batch norms failed, %s had no outputs.' %
                         match.output_tensor.name)
Developer: BhaskarNallani, Project: tensorflow, Lines: 79, Source: fold_batch_norms.py

Example 7: _ComputeBatchNormCorrections

def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
  """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen if global_step > bn_freeze_delay.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather
     than jumping across mini-batches.
     b) Changing the values of the corrections allows one to switch between
     using batch statistics and using the moving mean and variance, without
     requiring changes to batch_norm.


  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.
    fused_batch_norm: Bool, true if fused batch norm is used.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

  g = ops.get_default_graph()
  prefix = '' if not context else context + '/'
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    graph_editor.reroute_ts(
        [bn_decay_mean_out], [match.bn_decay_mean_tensor],
        can_modify=bn_decay_mean_consumers)

    if fused_batch_norm is False:
      bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
      bn_decay_var_out = utils.smart_cond(
          use_mv_avg,
          lambda: bn_decay_zero,
          lambda: match.bn_decay_var_tensor,
          name='freeze_moving_var')
      graph_editor.reroute_ts(
          [bn_decay_var_out], [match.bn_decay_var_tensor],
          can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')

    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
Developer: BhaskarNallani, Project: tensorflow, Lines: 99, Source: fold_batch_norms.py
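
Plugging sample statistics into the docstring's formulas makes the two regimes concrete. A small NumPy sketch with toy values (sigma_b is the batch standard deviation, sigma_mv the moving one):

import numpy as np

gamma, eps = 1.3, 1e-3
mu_b, var_b = 0.4, 2.0    # batch statistics
mu_mv, var_mv = 0.1, 1.5  # moving averages

sigma_b, sigma_mv = np.sqrt(var_b + eps), np.sqrt(var_mv + eps)
correction_scale = sigma_b / sigma_mv

# Before freezing: the reciprocal undoes the weight scaling after the
# conv and the offset is zero, so the output is unchanged.
recip_before, offset_before = 1.0 / correction_scale, 0.0
assert np.isclose(correction_scale * recip_before, 1.0)

# After freezing: the reciprocal collapses to 1 and the offset bridges
# the gap between batch and moving statistics.
recip_after = 1.0
offset_after = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)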

Example 8: test_compatibility

  def test_compatibility(self):
    with self.assertRaises(ValueError):
      ge.reroute_ts([self.a0, self.b0], [self.a2, self.b2])
Developer: 1000sprites, Project: tensorflow, Lines: 3, Source: reroute_test.py
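
reroute_ts first checks that each tensor pair is compatible and raises ValueError otherwise; in the test fixture, a2/b2 presumably differ from a0/b0 in dtype or shape. A minimal sketch of the same failure mode (illustrative fixture, not the actual reroute_test setup):

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

a0 = tf.constant(1.0, dtype=tf.float32, name='a0')
a2 = tf.constant(1.0, dtype=tf.float64, name='a2')
c2 = tf.identity(a2, name='c2')

try:
    ge.reroute_ts([a0], [a2])  # dtypes differ, so the pair is rejected
except ValueError as e:
    print('incompatible:', e)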

Example 9: _InsertQuantOp

  def _InsertQuantOp(
      self,
      context,
      producer,
      consumers,
      name,
      moving_avg=True,
      init_min=-6.0,
      init_max=6.0,
      delay_requested=True,
      bits=8,
      narrow_range=False,):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

    Args:
      context: Context where producer and consumer operations are nested.
      producer: Producer operation of the pairs where quantization will be
        inserted.
      consumers: Consumer operations of the pairs.
      name: Name for the new quantization op within the context.
      moving_avg: Specifies whether to use exponential moving average or just
        the last value seen.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      delay_requested: If true, implement quantization delay where needed.
        False value explicitly disables delay quantization everywhere.
      bits: Number of bits to use for quantization, must be between 2 and 8.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    Raises:
      ValueError: When producer operation is not directly connected to the
        consumer operation.
    """
    scope = context + '/' + name
    inputs = producer.outputs[0]
    if moving_avg:
      quant = (quant_ops.MovingAvgQuantize(
          inputs,
          init_min=init_min,
          init_max=init_max,
          ema_decay=self.ema_decay,
          is_training=self.is_training,
          num_bits=bits,
          narrow_range=narrow_range,
          updates_collection=_UPDATE_QUANT_OPS,
          vars_collection=self.vars_collection,
          scope=scope))
    else:
      quant = (quant_ops.LastValueQuantize(
          inputs,
          init_min=init_min,
          init_max=init_max,
          is_training=self.is_training,
          num_bits=bits,
          narrow_range=narrow_range,
          updates_collection=_UPDATE_QUANT_OPS,
          vars_collection=self.vars_collection,
          scope=scope))

    if delay_requested and self.quant_delay and self.quant_delay > 0:
      activate_quant = math_ops.greater_equal(
          training_util.get_or_create_global_step(),
          self.quant_delay,
          name=scope + '/activate_quant')
      quant = control_flow_ops.cond(
          activate_quant,
          lambda: quant,
          lambda: inputs,
          name=scope + '/delayed_quant')

    nodes_modified_count = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    if nodes_modified_count != len(consumers):
      raise ValueError('Some inputs not quantized for ops: [%s]' %
                       ', '.join([consumer.name for consumer in consumers]))
Developer: SylChan, Project: tensorflow, Lines: 75, Source: quantize.py

Example 10: _InsertQuantOp

def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                   narrow_range=False):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  name_prefix = _AddContextToName(context, name)
  inputs = producer.outputs[0]
  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  nodes_modified_count = graph_editor.reroute_ts(
      [quant], [inputs], can_modify=consumers)
  if nodes_modified_count != len(consumers):
    raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
        [consumer.name for consumer in consumers]))
Developer: dananjayamahesh, Project: tensorflow, Lines: 83, Source: quantize.py

Example 11: gradients


#......... some code omitted here .........
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoints nodes: {}".format(
            format_ops(ys_intersect_checkpoints)))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints, within_ops=fwd_ops)
    debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
        len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
    debug_print("ops_to_copy = {}".format(ops_to_copy))
    debug_print("Processing list {}".format(ys))
    _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired {} in place of {} restricted to {}".format(
        checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops))

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
    debug_print("Got gradients {}".format(dv))
    debug_print("for %s", copied_ys)
    debug_print("with respect to {}".format(boundary+xs))

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts))
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]
Developer: stonezuohui, Project: faceswap, Lines: 67, Source: memory_saving_gradients.py
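
This excerpt comes from an OpenAI-style gradient-checkpointing module: forward ops between checkpoints are copied, reroute_ts rewires the copies onto stop_gradient'd checkpoint tensors, and backprop then runs piecewise so only checkpoint activations stay in memory. A hedged usage sketch, assuming the module's drop-in gradients signature:

import tensorflow as tf
import memory_saving_gradients  # the module this example is excerpted from

x = tf.placeholder(tf.float32, [None, 16])
h = tf.layers.dense(x, 16, activation=tf.nn.relu)
loss = tf.reduce_mean(tf.layers.dense(h, 1))
params = tf.trainable_variables()

# checkpoints='memory' asks the module to pick roughly sqrt(n)
# checkpoint tensors automatically; an explicit tensor list also works.
grads = memory_saving_gradients.gradients([loss], params, checkpoints='memory')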


Note: The tensorflow.contrib.graph_editor.reroute_ts examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The snippets were selected from open-source projects contributed by many developers; copyright remains with the original authors, and redistribution and use are subject to each project's license. Do not reproduce without permission.