当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.assert_equal函数代码示例

本文整理汇总了Python中tensorflow.assert_equal函数的典型用法代码示例。如果您正苦于以下问题:Python assert_equal函数的具体用法?Python assert_equal怎么用?Python assert_equal使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了assert_equal函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: CombineArcAndRootPotentials

def CombineArcAndRootPotentials(arcs, roots):
  """Merges batched arc and root potentials into one potentials tensor.

  The root potentials are written onto the diagonal of the arc potentials
  with `tf.matrix_set_diag`, gated on runtime checks that the dynamic shapes
  of the two inputs agree.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
  """
  # Ranks have to be known at graph-construction time.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # Both inputs must have the same base dtype.
  dtype = arcs.dtype.base_dtype
  check.Same([dtype, roots.dtype.base_dtype], 'dtype mismatch')

  roots_dims = tf.shape(roots)
  arcs_dims = tf.shape(arcs)
  batch_size, num_tokens = roots_dims[0], roots_dims[1]

  # arcs must be [batch_size, num_tokens, num_tokens] at run time.
  expected_arc_dims = (batch_size, num_tokens, num_tokens)
  shape_checks = [
      tf.assert_equal(expected, arcs_dims[axis])
      for axis, expected in enumerate(expected_arc_dims)
  ]
  with tf.control_dependencies(shape_checks):
    return tf.matrix_set_diag(arcs, roots)
开发者ID:ALISCIFP,项目名称:models,代码行数:28,代码来源:digraph_ops.py

示例2: square_error

 def square_error(estimated, target):
     """Returns the summed squared difference of the inputs as a float tensor.

     All ops are created inside the 'evaluation' name scope.

     Args:
         estimated: tensor of estimated values.
         target: tensor of target values, used elementwise with `estimated`.

     Returns:
         `tf.to_float(sum((estimated - target) ** 2))`.
     """
     with tf.name_scope('evaluation'):
         # NOTE(review): this subtracts `target` from itself, so the guarded
         # quantity is built from an all-zero tensor -- confirm whether
         # `estimated` was intended on one side.
         self_difference = tf.to_int32(target) - tf.to_int32(target)
         guard = tf.assert_equal(count(self_difference), 0.)
         with tf.control_dependencies([guard]):
             # NOTE(review): the result of this assert is not used by any later
             # op, so in graph mode it is presumably never executed -- confirm.
             tf.assert_equal(count(tf.cast(target - estimated, tf.int32)), 0.)
             residual_squared = tf.pow(estimated - target, 2, name='squared_difference')
             total = tf.reduce_sum(residual_squared, name='summing_square_errors')
             return tf.to_float(total)
开发者ID:MehdiAB161,项目名称:Autoencoder-Stability,代码行数:8,代码来源:Evaluation.py

示例3: logp

 def logp(self, F, Y):
     """Returns the negated sparse softmax cross-entropy of labels `Y` under logits `F`.

     Two runtime checks gate the result: `Y` must have exactly one column, and
     the second dimension of `F` must equal `self.num_classes`.

     Args:
         F: logits; its second dimension is compared with `self.num_classes`.
         Y: labels with a single column; column 0 is used as the class index.

     Returns:
         The negated cross-entropy with a trailing axis of size 1 added.
     """
     single_column_check = tf.assert_equal(tf.shape(Y)[1], 1)
     class_count_check = tf.assert_equal(
         tf.cast(tf.shape(F)[1], settings.int_type),
         tf.cast(self.num_classes, settings.int_type))
     with tf.control_dependencies([single_column_check, class_count_check]):
         cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
             logits=F, labels=Y[:, 0])
         return -cross_entropy[:, None]
开发者ID:sanket-kamthe,项目名称:GPflow,代码行数:8,代码来源:likelihoods.py

示例4: prepare_serialized_examples

  def prepare_serialized_examples(self, serialized_example,
      max_quantized_value=2, min_quantized_value=-2):
    """Parses one serialized SequenceExample into batch-of-one tensors.

    Args:
      serialized_example: a serialized `tf.SequenceExample` with context
        features "id" (string) and "labels" (int64 list), and one string
        sequence feature per entry of `self.feature_names`.
      max_quantized_value: forwarded to `self.get_video_matrix`.
      min_quantized_value: forwarded to `self.get_video_matrix`.

    Returns:
      A tuple `(batch_video_ids, batch_video_matrix, batch_labels,
      batch_frames)`; each element is the corresponding per-example tensor
      with a leading axis of size 1 added.

    Raises:
      AssertionError: if `self.feature_names` is empty or its length differs
        from that of `self.feature_sizes`.
    """
    contexts, features = tf.parse_single_sequence_example(
        serialized_example,
        context_features={"id": tf.FixedLenFeature(
            [], tf.string),
                          "labels": tf.VarLenFeature(tf.int64)},
        sequence_features={
            feature_name : tf.FixedLenSequenceFeature([], dtype=tf.string)
            for feature_name in self.feature_names
        })

    # read ground truth labels: dense boolean vector of length num_classes
    # that is True at every label index.
    labels = (tf.cast(
        tf.sparse_to_dense(contexts["labels"].values, (self.num_classes,), 1,
            validate_indices=False),
        tf.bool))

    # loads (potentially) different types of features and concatenates them
    num_features = len(self.feature_names)
    assert num_features > 0, "No feature selected: feature_names is empty!"

    assert len(self.feature_names) == len(self.feature_sizes), \
    "length of feature_names (={}) != length of feature_sizes (={})".format( \
    len(self.feature_names), len(self.feature_sizes))

    num_frames = None  # the number of frames in the video (set by 1st feature)
    frame_count_checks = []  # runtime checks that all features agree on it
    feature_matrices = []  # one decoded matrix per feature
    for feature_name, feature_size in zip(self.feature_names,
                                          self.feature_sizes):
      feature_matrix, num_frames_in_this_feature = self.get_video_matrix(
          features[feature_name],
          feature_size,
          self.max_frames,
          max_quantized_value,
          min_quantized_value)
      if num_frames is None:
        num_frames = num_frames_in_this_feature
      else:
        # Keep the assert op: an assert that nothing depends on is never
        # executed in graph mode, so it is wired in below.
        frame_count_checks.append(
            tf.assert_equal(num_frames, num_frames_in_this_feature))

      feature_matrices.append(feature_matrix)

    # Ops created in this scope only run after the frame-count checks pass.
    with tf.control_dependencies(frame_count_checks):
      # cap the number of frames at self.max_frames
      num_frames = tf.minimum(num_frames, self.max_frames)

      # concatenate different features
      video_matrix = tf.concat(feature_matrices, 1)

    # convert to batch format.
    # TODO: Do proper batch reads to remove the IO bottleneck.
    batch_video_ids = tf.expand_dims(contexts["id"], 0)
    batch_video_matrix = tf.expand_dims(video_matrix, 0)
    batch_labels = tf.expand_dims(labels, 0)
    batch_frames = tf.expand_dims(num_frames, 0)

    return batch_video_ids, batch_video_matrix, batch_labels, batch_frames
开发者ID:vijayky88,项目名称:youtube-8m,代码行数:57,代码来源:readers.py

示例5: discretized_mix_logistic_loss

def discretized_mix_logistic_loss(y_hat, y, num_classes=256,
		log_scale_min=-7.0, reduce=True):
	'''Negative log-likelihood under a discretized mixture of logistics.

	The input is assumed to be scaled to [-1, 1]. The channel axis of y_hat
	must have a size divisible by 3: it is split into mixture logits, means
	and log-scales, in that order.

	Args:
		y_hat: Tensor [batch_size, channels, time_length], predicted output.
		y: Tensor [batch_size, time_length, 1], Target.
		num_classes: number of quantization levels; sets the bin half-width
			1 / (num_classes - 1).
		log_scale_min: lower bound applied to the predicted log-scales.
		reduce: if True return the summed loss, otherwise the per-position
			loss with a trailing axis of size 1.
	Returns:
		Tensor loss
	'''
	shape_checks = [
		tf.assert_equal(tf.mod(tf.shape(y_hat)[1], 3), 0),
		tf.assert_equal(tf.rank(y_hat), 3),
	]
	with tf.control_dependencies(shape_checks):
		nr_mix = tf.shape(y_hat)[1] // 3

	#[batch_size, channels, time_length] -> [batch_size, time_length, channels]
	params = tf.transpose(y_hat, [0, 2, 1])

	#split the channel axis into three groups of nr_mix parameters each
	logit_probs = params[:, :, :nr_mix]
	means = params[:, :, nr_mix:2 * nr_mix]
	log_scales = tf.maximum(params[:, :, 2* nr_mix: 3 * nr_mix], log_scale_min)

	#broadcast the target over the mixture axis
	targets = y * tf.ones(shape=[1, 1, nr_mix], dtype=tf.float32)

	centered = targets - means
	inv_stdv = tf.exp(-log_scales)

	#logistic CDF evaluated at the upper and lower edge of the target's bin
	upper_in = inv_stdv * (centered + 1. / (num_classes - 1))
	cdf_upper = tf.nn.sigmoid(upper_in)
	lower_in = inv_stdv * (centered - 1. / (num_classes - 1))
	cdf_lower = tf.nn.sigmoid(lower_in)

	#log CDF at the upper edge: used when the target is below -0.999
	log_cdf_upper = upper_in - tf.nn.softplus(upper_in)
	#log (1 - CDF) at the lower edge: used when the target is above 0.999
	log_one_minus_cdf_lower = -tf.nn.softplus(lower_in)

	#probability mass of the bin for all remaining targets
	bin_mass = cdf_upper - cdf_lower

	#log density at the bin centre; fallback when the bin mass is <= 1e-5
	centre_in = inv_stdv * centered
	log_pdf_centre = centre_in - log_scales - 2. * tf.nn.softplus(centre_in)

	interior_log_probs = tf.where(bin_mass > 1e-5,
		tf.log(tf.maximum(bin_mass, 1e-12)),
		log_pdf_centre - np.log((num_classes - 1) / 2))
	upper_or_interior = tf.where(targets > 0.999,
		log_one_minus_cdf_lower, interior_log_probs)
	log_probs = tf.where(targets < -0.999, log_cdf_upper, upper_or_interior)

	#add the log mixture weights
	log_probs = log_probs + log_prob_from_logits(logit_probs)

	mixture_log_likelihood = log_sum_exp(log_probs)
	if reduce:
		return -tf.reduce_sum(mixture_log_likelihood)
	return -tf.expand_dims(mixture_log_likelihood, [-1])
开发者ID:duvtedudug,项目名称:Tacotron-2,代码行数:57,代码来源:mixture.py

示例6: step

	def step(self, x, c=None, g=None, softmax=False):
		"""Runs the network forward on `x` with optional conditioning.

		Args:
			x: Tensor of shape [batch_size, channels, time_length], One-hot encoded audio signal.
			c: Tensor of shape [batch_size, cin_channels, time_length], Local conditioning features.
			g: Tensor of shape [batch_size, gin_channels, 1] or Ids of shape [batch_size, 1],
				Global conditioning features. When `self.embed_speakers` is set,
				`g` is passed through it after being reshaped to [batch_size, -1].
			softmax: Boolean, Whether to apply softmax (over axis 1) to the output.

		Returns:
			a Tensor of shape [batch_size, out_channels, time_length]
		"""
		#dynamic batch size and number of time steps of the input
		batch_size = tf.shape(x)[0]
		time_length = tf.shape(x)[-1]

		if g is not None and self.embed_speakers is not None:
			#[batch_size, 1] ==> [batch_size, 1, gin_channels]
			g = self.embed_speakers(tf.reshape(g, [batch_size, -1]))
			#the embedding must be rank 3 before it is moved to [batch_size, gin_channels, 1]
			rank_check = tf.assert_equal(tf.rank(g), 3)
			with tf.control_dependencies([rank_check]):
				g = tf.transpose(g, [0, 2, 1])

		#Expand global conditioning features to all time steps
		g_bct = _expand_global_features(batch_size, time_length, g, data_format='BCT')

		if c is not None and self.upsample_conv is not None:
			#[batch_size, 1, cin_channels, time_length]
			c = tf.expand_dims(c, axis=1)
			for upsample_layer in self.upsample_conv:
				c = upsample_layer(c)

			#[batch_size, cin_channels, time_length]
			c = tf.squeeze(c, [1])
			#upsampled conditioning must have as many time steps as x
			length_check = tf.assert_equal(tf.shape(c)[-1], tf.shape(x)[-1])
			with tf.control_dependencies([length_check]):
				c = tf.identity(c, name='control_c_and_x_shape')

		#Feed data to network, summing the second output of every layer
		x = self.first_conv(x)
		skips = None
		for layer in self.conv_layers:
			x, skip_out = layer(x, c, g_bct)
			skips = skip_out if skips is None else skips + skip_out
		x = skips

		for layer in self.last_conv_layers:
			x = layer(x)

		if softmax:
			return tf.nn.softmax(x, axis=1)
		return x
开发者ID:duvtedudug,项目名称:Tacotron-2,代码行数:56,代码来源:wavenet.py

示例7: _get_window

 def _get_window(window_length, dtype):
   """Builds the window tensor selected by `self._window`.

   `self` and `frame_size` are not parameters of this function; they are read
   from the enclosing scope, which is not visible here.

   :param window_length: number of samples in the window
   :param dtype: dtype of the returned window; the "blackman" window ignores
     it and is always float32
   :return: window tensor
   :raises ValueError: if `self._window` is not a supported window name
   """
   if self._window == "hanning":
     return tf.contrib.signal.hann_window(window_length, dtype=dtype)
   if self._window == "blackman":
     import scipy.signal
     # The window is built from `frame_size`, so it must match the requested
     # length. The identity op is created under the control dependency so the
     # check actually runs with the returned tensor.
     with tf.control_dependencies([tf.assert_equal(frame_size, window_length)]):
       return tf.identity(
         tf.constant(scipy.signal.blackman(frame_size), dtype=tf.float32))
   if self._window == "None" or self._window == "ones":
     return tf.ones((window_length,), dtype=dtype)
   # Previously an unknown name fell through and failed with an unbound local.
   raise ValueError("unknown window type: %r" % (self._window,))
开发者ID:rwth-i6,项目名称:returnn,代码行数:10,代码来源:TFNetworkSigProcLayer.py

示例8: __init__

  def __init__(self, l_overwrite=None, p_overwrite=None, q_overwrite=None, filter_input=None, parameters=None, noise_estimation=None, average_parameters=False, **kwargs):
    """
    Builds a parametric wiener filter on `filter_input` and stores its frequency domain output as the layer output.

    :param float|None l_overwrite: if given overwrites the l value of the parametric wiener filter with the given constant
    :param float|None p_overwrite: if given overwrites the p value of the parametric wiener filter with the given constant
    :param float|None q_overwrite: if given overwrites the q value of the parametric wiener filter with the given constant
    :param LayerBase|None filter_input: name of layer containing input for wiener filter
    :param LayerBase|None parameters: name of layer containing parameters for wiener filter
    :param LayerBase|None noise_estimation: name of layer containing noise estimate for wiener filter
    :param bool average_parameters: if set to true the parameters l, p and q are averaged over the time axis
    """
    from tfSi6Proc.audioProcessing.enhancement.singleChannel import TfParametricWienerFilter
    super(ParametricWienerFilterLayer, self).__init__(**kwargs)

    class _NoiseEstimator(object):
      """Wraps a fixed noise power spectrum tensor behind a getter method."""

      def __init__(self, noise_power_spectrum_tensor):
        self._noise_power_spectrum_tensor = noise_power_spectrum_tensor

      @classmethod
      def from_layer(cls, layer):
        """Creates the estimator from the batch-major output of `layer`."""
        return cls(layer.output.get_placeholder_as_batch_major())

      def getNoisePowerSpectrum(self):
        return self._noise_power_spectrum_tensor

    def _getParametersFromConstructorInputs(parameters, l_overwrite, p_overwrite, q_overwrite, average_parameters):
      """Returns (l, p, q), each either an overwrite constant or a slice of the parameter layer output."""
      parameter_vector = None
      if parameters is not None:
        parameter_vector = parameters.output.get_placeholder_as_batch_major()
        tf.assert_equal(parameter_vector.shape[-1], 3)

      overwrites = (l_overwrite, p_overwrite, q_overwrite)
      if any(value is None for value in overwrites):
        # at least one value has to come from the parameter layer
        assert parameter_vector is not None
        if average_parameters:
          # replace every time step by the mean over axis 1
          time_mean = tf.reduce_mean(parameter_vector, axis=1, keep_dims=True)
          parameter_vector = tf.tile(time_mean, [1, tf.shape(parameter_vector)[1], 1])

      def _select(overwrite, index):
        if overwrite is not None:
          return tf.constant(overwrite, dtype=tf.float32)
        return tf.expand_dims(parameter_vector[:, :, index], axis=-1)

      # last-axis indices 0, 1, 2 hold l, p and q respectively
      return tuple(_select(value, index) for index, value in enumerate(overwrites))

    spectrum = filter_input.output.get_placeholder_as_batch_major()
    if spectrum.dtype != tf.complex64:
      spectrum = tf.cast(spectrum, dtype=tf.complex64)
    # noise estimate and filter input are compared on their static last dimension
    noise_last_dim = noise_estimation.output.get_placeholder_as_batch_major().shape[-1]
    tf.assert_equal(noise_last_dim, spectrum.shape[-1])
    noise_estimator = _NoiseEstimator.from_layer(noise_estimation)
    l, p, q = _getParametersFromConstructorInputs(parameters, l_overwrite, p_overwrite, q_overwrite, average_parameters)
    wiener = TfParametricWienerFilter(noise_estimator, [], l, p, q, inputTensorFreqDomain=spectrum)
    self.output.placeholder = wiener.getFrequencyDomainOutputSignal()
开发者ID:rwth-i6,项目名称:returnn,代码行数:55,代码来源:TFNetworkSigProcLayer.py

示例9: compute_loss

  def compute_loss(self, unreduced_loss):
    """Computes scaled loss based on mask out size.

    Sets `self.loss_total`, `self.loss_mask`, `self.loss_unmask`,
    `self.reduced_mask_size`, `self.reduced_unmask_size` and `self.loss`.

    Args:
      unreduced_loss: elementwise loss tensor; it is multiplied by the padding
        mask built from `self.lengths` before any reduction.
    """
    # Padding mask: 1 where the time index (axis 1) is below the example's
    # length, 0 on the padding that makes the batch rectangular.
    batch_duration = tf.shape(self.pianorolls)[1]
    time_indices = tf.to_float(tf.range(batch_duration))
    pad_mask = tf.to_float(
        time_indices[None, :, None, None] < self.lengths[:, None, None, None])

    # Mask and its complement, both restricted to non-padded positions.
    mask = pad_mask * self.masks
    unmask = pad_mask * (1. - self.masks)

    # Number of variables per example: #timesteps * #variables per timestep.
    if self.hparams.use_softmax_loss:
      variable_axis = 3
    else:
      variable_axis = 2
    variables_per_step = tf.to_float(tf.shape(self.pianorolls)[variable_axis])
    num_variables = self.lengths[:, None, None, None] * variables_per_step
    total_variables = tf.reduce_sum(num_variables)

    # Numbers of variables to be predicted / conditioned on.
    reduce_axes = [1, variable_axis]
    mask_size = tf.reduce_sum(mask, axis=reduce_axes, keep_dims=True)
    unmask_size = tf.reduce_sum(unmask, axis=reduce_axes, keep_dims=True)

    unreduced_loss *= pad_mask
    if self.hparams.rescale_loss:
      unreduced_loss *= num_variables / mask_size

    # Average loss over the entire set of variables.
    self.loss_total = tf.reduce_sum(unreduced_loss) / total_variables

    # Separate losses for masked/unmasked variables.
    # NOTE: indexing the pitch dimension with 0 because the mask is constant
    # across pitch. Except in the sigmoid case, but then the pitch dimension
    # will have been reduced over.
    self.reduced_mask_size = tf.reduce_sum(mask_size[:, :, 0, :])
    self.reduced_unmask_size = tf.reduce_sum(unmask_size[:, :, 0, :])

    # mask and unmask must be disjoint and together cover all variables.
    disjoint_check = tf.assert_equal(tf.reduce_sum(mask * unmask), 0.)
    cover_check = tf.assert_equal(
        self.reduced_mask_size + self.reduced_unmask_size, total_variables)
    assert_partition_op = tf.group(disjoint_check, cover_check)
    with tf.control_dependencies([assert_partition_op]):
      masked_total = tf.reduce_sum(mask * unreduced_loss)
      self.loss_mask = masked_total / self.reduced_mask_size
      unmasked_total = tf.reduce_sum(unmask * unreduced_loss)
      self.loss_unmask = unmasked_total / self.reduced_unmask_size

    # Pick the objective.
    if self.hparams.optimize_mask_only:
      self.loss = self.loss_mask
    else:
      self.loss = self.loss_total
开发者ID:czhuang,项目名称:magenta-autofill,代码行数:52,代码来源:lib_graph.py

示例10: _kl_independent

def _kl_independent(a, b, name="kl_independent"):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the statically known event shapes of `a` and `b` don't
      match.
    NotImplementedError: If the statically known event shapes of the
      underlying distributions don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined():
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        # Reduce over the trailing axes [-1, ..., -num_reduce_dims].
        num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

        return tf.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError("KL between Independents with different "
                                  "event shapes not supported.")
    else:
      raise ValueError("Event shapes do not match.")
  else:
    with tf.control_dependencies([
        tf.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()),
        tf.assert_equal(p.event_shape_tensor(), q.event_shape_tensor())
    ]):
      # The event rank is the length of the event-shape vector, so take the
      # shape of the whole vector and then index it (not the shape of its
      # first, scalar, element).
      num_reduce_dims = (
          tf.shape(a.event_shape_tensor())[0] - tf.shape(
              p.event_shape_tensor())[0])
      # Trailing axes [-num_reduce_dims, ..., -1]: the same set of axes the
      # static branch reduces over.
      reduce_dims = tf.range(-num_reduce_dims, 0, 1)
      return tf.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
开发者ID:asudomoeva,项目名称:probability,代码行数:52,代码来源:independent.py

示例11: _build_clp_multiplication

 def _build_clp_multiplication(self, clp_kernel):
   """Applies the complex kernel to the interleaved complex input and returns the log magnitude.

   The input's last axis holds real parts at even indices and imaginary parts
   at odd indices. The kernel's first axis selects real (0) or imaginary (1)
   weights.

   :param clp_kernel: kernel whose dynamic shape is validated: axis 1 must be
     half the input's last dimension and axis 2 must equal `self._nr_of_filters`
   :return: `safe_log` of the magnitude of the complex product, indexed
     batch, time, filter
   """
   from TFUtil import safe_log
   input_placeholder = self.input_data.get_placeholder_as_batch_major()
   shape_checks = [
     tf.assert_equal(tf.shape(clp_kernel)[1], tf.shape(input_placeholder)[2] // 2),
     tf.assert_equal(tf.shape(clp_kernel)[2], self._nr_of_filters),
   ]
   # An assert op that nothing depends on never runs in graph mode; route the
   # input through an identity created under the checks so they are enforced.
   with tf.control_dependencies(shape_checks):
     input_placeholder = tf.identity(input_placeholder)
   input_real = tf.strided_slice(input_placeholder, [0, 0, 0], tf.shape(input_placeholder), [1, 1, 2])
   input_imag = tf.strided_slice(input_placeholder, [0, 0, 1], tf.shape(input_placeholder), [1, 1, 2])
   # NOTE(review): the weights are read from self._clp_kernel while the checks
   # above validate the `clp_kernel` argument -- confirm callers always pass
   # self._clp_kernel.
   kernel_real = self._clp_kernel[0, :, :]
   kernel_imag = self._clp_kernel[1, :, :]
   # complex product: (a + ib)(c + id) = (ac - bd) + i(bc + ad)
   output_real = tf.einsum('btf,fp->btp', input_real, kernel_real) - tf.einsum('btf,fp->btp', input_imag, kernel_imag)
   output_imag = tf.einsum('btf,fp->btp', input_imag, kernel_real) + tf.einsum('btf,fp->btp', input_real, kernel_imag)
   output_uncompressed = tf.sqrt(tf.pow(output_real, 2) + tf.pow(output_imag, 2))
   output_compressed = safe_log(output_uncompressed)
   return output_compressed
开发者ID:rwth-i6,项目名称:returnn,代码行数:14,代码来源:TFNetworkSigProcLayer.py

示例12: test_doesnt_raise_when_both_empty

 def test_doesnt_raise_when_both_empty(self):
   """Evaluating a tensor gated on assert_equal of two empty constants must not raise."""
   with self.test_session():
     first_empty = tf.constant([])
     second_empty = tf.constant([])
     equality_check = tf.assert_equal(first_empty, second_empty)
     with tf.control_dependencies([equality_check]):
       gated = tf.identity(first_empty)
     gated.eval()
开发者ID:3kwa,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py

示例13: sample_from_discretized_mix_logistic

def sample_from_discretized_mix_logistic(y, log_scale_min=-7.):
	'''Draws a sample from the mixture of logistics parameterized by y.

	The channel axis of y must have a size divisible by 3: it is split into
	mixture logits, means and log-scales, in that order.

	Args:
		y: Tensor, [batch_size, channels, time_length]
		log_scale_min: lower bound applied to the selected log-scales.
	Returns:
		Tensor: sample in range of [-1, 1]
	'''
	divisibility_check = tf.assert_equal(tf.mod(tf.shape(y)[1], 3), 0)
	with tf.control_dependencies([divisibility_check]):
		nr_mix = tf.shape(y)[1] // 3

	#[batch_size, channels, time_length] -> [batch_size, time_length, channels]
	params = tf.transpose(y, [0, 2, 1])
	logit_probs = params[:, :, :nr_mix]

	#pick one mixture component per position: perturb the logits with
	#-log(-log(uniform)) noise and take the argmax
	uniform_noise = tf.random_uniform(tf.shape(logit_probs), minval=1e-5, maxval=1. - 1e-5)
	perturbed_logits = logit_probs - tf.log(-tf.log(uniform_noise))
	component = tf.argmax(perturbed_logits, -1)

	#[batch_size, time_length] -> [batch_size, time_length, nr_mix]
	selector = tf.one_hot(component, depth=nr_mix, dtype=tf.float32)

	#keep only the parameters of the chosen component
	means = tf.reduce_sum(params[:, :, nr_mix:2 * nr_mix] * selector, axis=-1)
	chosen_log_scales = tf.reduce_sum(
		params[:, :, 2 * nr_mix:3 * nr_mix] * selector, axis=-1)
	log_scales = tf.maximum(chosen_log_scales, log_scale_min)

	#inverse-CDF sample from the logistic; no rounding to 8-bit values here
	u = tf.random_uniform(tf.shape(means), minval=1e-5, maxval=1. - 1e-5)
	sample = means + tf.exp(log_scales) * (tf.log(u) - tf.log(1 -u))

	#clip to [-1, 1]
	lower_clipped = tf.maximum(sample, -1.)
	return tf.minimum(lower_clipped, 1.)
开发者ID:duvtedudug,项目名称:Tacotron-2,代码行数:32,代码来源:mixture.py

示例14: _maybe_validate_perm

def _maybe_validate_perm(perm, validate_args, name=None):
  """Checks that `perm` is a valid permutation vector.

  Whatever can be decided statically is checked immediately and raises.
  Anything that cannot is turned into a runtime assertion, but only when
  `validate_args` is true.

  Args:
    perm: integer tensor to validate.
    validate_args: whether to create runtime assertions for checks that
      cannot be done statically.
    name: optional name scope for the created ops.

  Returns:
    List of assertion ops (possibly empty).

  Raises:
    TypeError: if `perm` is not of integer dtype.
    ValueError: if `perm` is statically known not to be rank 1, or statically
      known not to be a permutation of `0..size-1`.
  """
  with tf.name_scope(name, 'maybe_validate_perm', [perm]):
    if not perm.dtype.is_integer:
      raise TypeError('`perm` must be integer type')

    assertions = []

    # Rank check: static when the rank is known, otherwise deferred.
    msg = '`perm` must be a vector.'
    static_rank = perm.shape.ndims
    if static_rank is None:
      if validate_args:
        assertions.append(tf.assert_rank(perm, 1, message=msg))
    elif static_rank != 1:
      raise ValueError(msg[:-1] + ', saw rank: {}.'.format(static_rank))

    # Permutation check: the sorted values must equal 0..size-1.
    msg = '`perm` must be a valid permutation vector.'
    perm_ = tf.contrib.util.constant_value(perm)
    if perm_ is None:
      if validate_args:
        assertions.append(tf.assert_equal(
            tf.contrib.framework.sort(perm),
            tf.range(tf.size(perm)),
            message=msg))
    else:
      is_permutation = np.all(np.arange(np.size(perm_)) == np.sort(perm_))
      if not is_permutation:
        raise ValueError(msg[:-1] + ', saw: {}.'.format(perm_))

    return assertions
开发者ID:asudomoeva,项目名称:probability,代码行数:27,代码来源:transpose.py

示例15: zero_state

    def zero_state(self, batch_size, dtype):
        """Builds the initial `AttentionWrapperState` for this wrapper.

        The cell state is `self._initial_cell_state` when that is set, and the
        wrapped cell's zero state otherwise. It is passed through identity ops
        gated on a check that `batch_size` equals the attention mechanism's
        batch size.

        Args:
            batch_size: requested batch size.
            dtype: dtype for the created state tensors.

        Returns:
            An `AttentionWrapperState` with time 0 and an empty alignment
            history.
        """
        scope_name = type(self).__name__ + "ZeroState"
        with tf.name_scope(scope_name, values=[batch_size]):
            cell_state = self._initial_cell_state
            if cell_state is None:
                cell_state = self._cell.zero_state(batch_size, dtype)

            error_message = (
                "zero_state of AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.")
            batch_size_check = tf.assert_equal(
                batch_size,
                self._attention_mechanism.batch_size,
                message=error_message)
            # Every tensor of the (possibly nested) cell state is routed
            # through an identity that depends on the batch-size check.
            with tf.control_dependencies([batch_size_check]):
                cell_state = nest.map_structure(
                    lambda s: tf.identity(s, name="checked_cell_state"),
                    cell_state)

            initial_time = tf.zeros([], dtype=tf.int32)
            initial_attention = rnn_cell_impl._zero_state_tensors(
                self._attention_size, batch_size, dtype)
            initial_alignments = self._attention_mechanism.initial_alignments(
                batch_size, dtype)
            return AttentionWrapperState(
                cell_state=cell_state,
                time=initial_time,
                attention=initial_attention,
                alignments=initial_alignments,
                alignment_history=())
开发者ID:laurii,项目名称:DeepChatModels,代码行数:28,代码来源:_rnn.py


注:本文中的tensorflow.assert_equal函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。