

Python tensorflow.while Method Code Examples

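This page collects typical usage examples of the tensorflow.while method in Python. If you want to know what tensorflow.while does, how to use it, or what real-world usage looks like, the selected examples here may help. Note that TensorFlow has no attribute literally named `while`, because `while` is a reserved Python keyword. The loop construct these examples call or discuss is `tf.while_loop`. You can also explore more usage examples from the tensorflow module, where this method lives.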


Three code examples of tensorflow.while are shown below, sorted by popularity by default.

Example 1: testUseWithinWhileLoop

# Required import: import tensorflow [as alias]
# Or: from tensorflow import while_loop [as alias]
def testUseWithinWhileLoop(self):
    with tf.Graph().as_default():
      spec = hub.create_module_spec(double_module_fn)
      m = hub.Module(spec)
      i = tf.constant(0)
      x = tf.constant(10.0)
      p = tf_v1.placeholder(dtype=tf.int32)
      c = lambda i, x: tf.less(i, p)
      b = lambda i, x: (tf.add(i, 1), m(x))
      oi, ox = tf.while_loop(c, b, [i, x])  # ox = v**p * x
      v = m.variables[0]
      dodv = tf.gradients(ox, v)[0]  # d ox / dv = p*v**(p-1) * x
      dodx = tf.gradients(ox, x)[0]  # d ox / dx = v**p
      with tf_v1.Session() as sess:
        sess.run(tf_v1.global_variables_initializer())
        self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 1}), [1, 20])
        self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 2}), [2, 40])
        self.assertAllEqual(sess.run([oi, ox], feed_dict={p: 4}), [4, 160])
        # Gradients also use the control flow structures setup earlier.
        # Also check they are working properly.
        self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 1}), [10, 2])
        self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 2}), [40, 4])
        self.assertAllEqual(sess.run([dodv, dodx], feed_dict={p: 4}), [320, 16])

Developer: tensorflow, Project: hub, Lines of code: 27, Source: native_module_test.py
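The assertions in Example 1 follow from the semantics of `tf.while_loop(cond, body, loop_vars)`: as long as `cond(*loop_vars)` is true, the loop variables are replaced by `body(*loop_vars)`. The sketch below is a rough check. It emulates that behaviour in plain Python, with no TensorFlow, no graph and no automatic differentiation, and it estimates the gradients with finite differences. It assumes, as the expected values imply, that the module built from `double_module_fn` multiplies its input by a single variable `v` initialised to 2.0. The names `while_loop_emulated`, `forward` and `numeric_grad` are hypothetical and are introduced here only for illustration.

```python
def while_loop_emulated(cond, body, loop_vars):
    """Plain-Python stand-in for tf.while_loop: rebind loop_vars while cond holds."""
    loop_vars = tuple(loop_vars)
    while cond(*loop_vars):
        loop_vars = tuple(body(*loop_vars))
    return loop_vars


def forward(p, v=2.0, x=10.0):
    """Example 1's loop with the ASSUMED module m(x) = v * x, applied p times."""
    cond = lambda i, acc: i < p
    body = lambda i, acc: (i + 1, v * acc)
    return while_loop_emulated(cond, body, (0, x))


def numeric_grad(f, at, eps=1e-6):
    """Central finite difference, used only to sanity-check the analytic gradients."""
    return (f(at + eps) - f(at - eps)) / (2 * eps)


for p in (1, 2, 4):
    oi, ox = forward(p)
    dodv = numeric_grad(lambda v: forward(p, v=v)[1], 2.0)   # expect p * v**(p-1) * x
    dodx = numeric_grad(lambda x: forward(p, x=x)[1], 10.0)  # expect v**p
    print(p, oi, ox, round(dodv, 3), round(dodx, 3))
# Output:
# 1 1 20.0 10.0 2.0
# 2 2 40.0 40.0 4.0
# 4 4 160.0 320.0 16.0
```

The finite-difference error is far below the three decimal places printed, so each row lines up with the test's hand-derived formulas `ox = v**p * x`, `d ox / dv = p*v**(p-1) * x` and `d ox / dx = v**p`.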

Example 2: __init__

# Required import: import tensorflow [as alias]
# Or: from tensorflow import while_loop [as alias]
def __init__(self, input_type=None, output_type=None, name_or_scope=None):
    """Creates the layer.

    Args:
      input_type: A type.
      output_type: A type.
      name_or_scope: A string or variable scope. If a string, a new variable
        scope will be created by calling
        [`create_variable_scope`](#create_variable_scope), with defaults
        inherited from the current variable scope. If no caching device is set,
        it will be set to `lambda op: op.device`. This is because `tf.while` can
        be very inefficient if the variables it uses are not cached locally.
    """
    if name_or_scope is None: name_or_scope = type(self).__name__
    if isinstance(name_or_scope, tf.VariableScope):
      self._vscope = name_or_scope
      name = str(self._vscope.name)
    elif isinstance(name_or_scope, six.string_types):
      self._vscope = create_variable_scope(name_or_scope)
      name = name_or_scope
    else:
      raise TypeError('name_or_scope must be a tf.VariableScope or a string: '
                      '%s' % (name_or_scope,))
    if self._vscope.caching_device is None:
      self._vscope.set_caching_device(lambda op: op.device)
    super(Layer, self).__init__(input_type, output_type, name)

    if not hasattr(self, '_constructor_name'):
      self._constructor_name = '__.%s' % self.__class__.__name__
    if not hasattr(self, '_constructor_args'):
      self._constructor_args = None
    if not hasattr(self, '_constructor_kwargs'):
      self._constructor_kwargs = None 
Developer: tensorflow, Project: fold, Lines of code: 35, Source: layers.py
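Example 2 takes either an existing variable scope or a string for `name_or_scope` and falls back to the class name when neither is given. It also installs a default caching device, because variables used inside `tf.while` loops can be very inefficient if they are not cached locally. The sketch below isolates that argument-normalization logic in plain Python. `Scope` is a hypothetical stand-in for `tf.VariableScope` and is not TensorFlow's class. `resolve_scope` and `MyLayer` are likewise hypothetical. Where the original calls the project's `create_variable_scope`, the sketch builds a `Scope` directly. It also uses `str` in place of `six.string_types`, because the sketch assumes Python 3. The "op" passed to the caching function is faked with `SimpleNamespace`.

```python
from types import SimpleNamespace


class Scope:
    """Hypothetical stand-in for tf.VariableScope: a name and an optional caching device."""

    def __init__(self, name, caching_device=None):
        self.name = name
        self.caching_device = caching_device

    def set_caching_device(self, caching_device):
        self.caching_device = caching_device


def resolve_scope(owner, name_or_scope=None):
    """Normalize name_or_scope the way Example 2's __init__ does; returns (scope, name)."""
    if name_or_scope is None:
        # Fall back to the owning class's name, like type(self).__name__.
        name_or_scope = type(owner).__name__
    if isinstance(name_or_scope, Scope):
        scope, name = name_or_scope, str(name_or_scope.name)
    elif isinstance(name_or_scope, str):
        scope, name = Scope(name_or_scope), name_or_scope
    else:
        raise TypeError('name_or_scope must be a Scope or a string: %r' % (name_or_scope,))
    if scope.caching_device is None:
        # Default: cache each variable on the device of the op that reads it.
        scope.set_caching_device(lambda op: op.device)
    return scope, name


class MyLayer:  # hypothetical layer class, used only for this demo
    pass


scope, name = resolve_scope(MyLayer())
print(name, scope.caching_device(SimpleNamespace(device='/gpu:0')))  # MyLayer /gpu:0
```

As in the original, a scope passed in by the caller is modified in place only when it has no caching device yet. An explicitly chosen caching device is left alone.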

Example 3: forward

# Required import: import tensorflow [as alias]
# Or: from tensorflow import while_loop [as alias]
def forward(self, image, reference):
		'''Evaluates distances between images in 'image' and 'reference' (data in NHWC order).
		   Returns an N-element distance vector.
		   
		   If 'image' is a tuple, evaluates all the images in the tuple with the same input transformations
		   and dropout as 'reference'. A different set of input transformations for each would result in
		   unnecessary uncertainty in determining which of the images is closest to the reference. The
		   returned value is a tuple of N-element distance vectors.'''
		  
		if isinstance(image, list):
			raise Exception('Parameter \'image\' must be a tensor or a tuple of tensors.')
		
		image_in = as_tuple(image)
		
		def cond(i, loss_sum):
			return tf.less(i, tf.cast(self.config.average_over, tf.int32))
		
		def body(i, loss_sum):
			ensemble = self.sample_ensemble(self.config)
			
			ensemble_X = for_each(image_in, lambda X: apply_ensemble(self.config, ensemble, X))
			ensemble_X = for_each(ensemble_X, lambda X: 2.0 * X - 1.0)

			ensemble_R = apply_ensemble(self.config, ensemble, reference)			
			ensemble_R = 2.0 * ensemble_R - 1.0
			
			loss = self.network.forward(ensemble_X, ensemble_R)
			loss_sum += tf.stack(loss, axis=0)
			
			loss_sum.set_shape([len(image_in), self.config.batch_size])
			
			return i+1, loss_sum

		if isinstance(self.config.average_over, numbers.Number) and self.config.average_over == 1:
			# Skip tf.while for trivial single iterations.
			_, loss_sum = body(0, tf.zeros([len(image_in), self.config.batch_size], dtype=self.config.dtype))
		else:
			# Run multiple times for any other average_over count.
			_, loss_sum = tf.while_loop(cond, body, (0, tf.zeros([len(image_in), self.config.batch_size], dtype=self.config.dtype)), back_prop=self.back_prop)
			loss_sum /= tf.cast(self.config.average_over, self.config.dtype)

		
		if isinstance(image, tuple):
			return tuple((loss_sum[i, :] for i in range(len(image))))
		else:
			return tf.reshape(loss_sum, [self.config.batch_size]) 
Developer: mkettune, Project: elpips, Lines of code: 48, Source: elpips.py
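Example 3 averages the distance over `average_over` randomly sampled sets of input transformations. Its `body` adds one sample's distances to a running sum, and `tf.while_loop` repeats that step while `i < average_over`. Afterwards the sum is divided by the count, except in a shortcut that skips the loop when only one sample is requested. The plain-Python sketch below mirrors this accumulate-then-divide pattern for a single image, whereas the original stacks one row per image in a tuple. `averaged_distance` and the `sample_distances` callable are hypothetical. The explicit `while` statement stands in for what `tf.while_loop(cond, body, (0, zeros))` computes. None of this is TensorFlow or E-LPIPS code.

```python
import random


def averaged_distance(sample_distances, average_over, batch_size):
    """Element-wise mean of `average_over` calls to sample_distances().

    sample_distances is a hypothetical callable returning batch_size floats,
    one distance per example, for one random draw of input transformations.
    """
    def cond(i, loss_sum):
        return i < average_over

    def body(i, loss_sum):
        loss = sample_distances()
        # Element-wise running sum, like `loss_sum += tf.stack(loss, axis=0)`.
        return i + 1, [s + l for s, l in zip(loss_sum, loss)]

    zeros = [0.0] * batch_size
    if average_over == 1:
        # Trivial case, as in Example 3: run the body once, no loop, no division.
        _, loss_sum = body(0, zeros)
        return loss_sum

    # Plain-Python equivalent of tf.while_loop(cond, body, (0, zeros)).
    i, loss_sum = 0, zeros
    while cond(i, loss_sum):
        i, loss_sum = body(i, loss_sum)
    return [s / average_over for s in loss_sum]


# Usage: a noisy, hypothetical distance whose true mean is 1.0 for every example.
rng = random.Random(0)
noisy = lambda: [rng.gauss(1.0, 0.1) for _ in range(4)]
print(averaged_distance(noisy, 100, batch_size=4))  # entries should be near 1.0
```

Each extra sample costs one more pass through the network but reduces the variance of the averaged estimate.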


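Note: The tensorflow.while examples on this page were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by various developers, and copyright in the source code remains with the original authors. Consult each project's license before distributing or using the code, and do not republish without permission.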