当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.batch_gather方法代码示例

本文整理汇总了Python中tensorflow.batch_gather方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.batch_gather方法的具体用法？Python tensorflow.batch_gather怎么用？Python tensorflow.batch_gather使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的用法示例。


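在下文中一共展示了tensorflow.batch_gather方法的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。注意：tf.batch_gather 已被弃用，在 TensorFlow 2.x 中仅以 tf.compat.v1.batch_gather 的形式保留，官方建议改用带 batch_dims 参数的 tf.gather（示例6展示了兼容两者的写法）。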

示例1: _call

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def _call(self, inputs):
    """Sample neighbours for each node, weighted by feature similarity.

    ``inputs`` is unpacked as ``(ids, num_samples, features, batch_size)``.
    Candidate neighbours of each node come from ``self.adj_info[ids]``.
    Each candidate gets weight ``exp(-||x_node - x_neigh||_2)``; the weights
    are normalised per node, and any probability <= 0.001 is set to zero.
    ``num_samples`` candidates are then drawn per node with
    ``tf.random.categorical``.

    Returns:
        The drawn entries of the gathered adjacency rows, one row per node.
    """
    ids, num_samples, features, batch_size = inputs
    threshold = 0.001
    neighs = self.num_neighs

    adj_lists = tf.gather(self.adj_info, ids)
    # Repeat each node's feature vector once per neighbour slot so that it
    # lines up element-wise with the gathered neighbour features.
    centre = tf.reshape(
        tf.tile(tf.gather(features, ids), [1, neighs]),
        [batch_size, neighs, tf.shape(features)[-1]])
    neighbours = tf.gather(features, adj_lists)

    dist = tf.sqrt(tf.reduce_sum(tf.square(centre - neighbours), -1))
    weights = tf.exp(-dist)
    totals = tf.tile(tf.reduce_sum(weights, -1, keepdims=True), [1, neighs])
    weights = tf.divide(weights, totals)
    # Zero out tiny probabilities; their log becomes -inf below, so those
    # neighbours are not drawn.
    weights = tf.where(weights > threshold, weights, 0 * weights)

    picks = tf.random.categorical(tf.math.log(weights), num_samples)
    return tf.batch_gather(adj_lists, picks)
开发者ID:safe-graph,项目名称:DGFraud,代码行数:20,代码来源:neigh_samplers.py

示例2: gather

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def gather(self, data, pl_idx, pl_mask, max_len, name=None):
    """Batched lookup for tensors of rank > 2, with padding masked out.

    For every batch row ``b`` and each of its ``max_len`` indices ``j`` this
    picks ``data[b, pl_idx[b, j]]`` via ``tf.gather_nd``. It then multiplies
    the result by ``pl_mask`` so that padded index slots contribute nothing.
    (Could be simplified with ``tf.batch_gather``.)

    Parameters
    ----------
    data:       Tensor to look up from; axis 0 is the batch axis of size
                ``self.p.batch_size``.
    pl_idx:     Indices into ``data``'s axis 1; must hold
                ``batch_size * max_len`` entries in row-major (batch, slot)
                order.
    pl_mask:    Padding mask for ``pl_idx``; gets a trailing axis added and
                is multiplied into the result (presumably shaped
                ``[batch_size, max_len]`` — confirm).
    max_len:    Number of indices per batch row.
    name:       Unused; kept for interface compatibility.

    Returns
    -------
    ``[batch_size, max_len, -1]`` tensor of gathered vectors times the mask.
    """
    batch_size = self.p.batch_size
    # Row (batch) coordinate of every lookup: 0 repeated max_len times, then 1, ...
    row_idx = tf.reshape(tf.range(batch_size, dtype=tf.int32), [-1, 1])
    row_idx = tf.reshape(tf.tile(row_idx, [1, max_len]), [-1, 1])
    col_idx = tf.reshape(pl_idx, [-1, 1])
    indices = tf.concat((row_idx, col_idx), axis=1)
    et_vecs = tf.gather_nd(data, indices)
    # Reshape with the same max_len that was used to build the row indices,
    # so the result layout always matches the requested lookup length.
    et_vecs = tf.reshape(et_vecs, [batch_size, max_len, -1])
    mask_vec = tf.expand_dims(pl_mask, axis=2)
    return et_vecs * mask_vec
开发者ID:malllabiisc,项目名称:NeuralDater,代码行数:27,代码来源:neural_dater.py

示例3: _interpolate

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def _interpolate(self, xy1, xy2, points2):
    """Interpolate ``points2`` (attached to ``xy2``) at the locations ``xy1``.

    For every point of ``xy1`` the 3 nearest points of ``xy2`` (Euclidean
    distance) are found. Their rows of ``points2`` are blended with
    inverse-distance weights that are normalised to sum to 1.

    Args:
        xy1: query coordinates, rank 3 with batch on axis 0
            (presumably [B, N1, 2] — confirm).
        xy2: reference coordinates, [B, N2, same last dim as xy1].
        points2: features of the reference points, [B, N2, C]; C must be
            statically known (read via ``get_shape()[-1].value``).

    Returns:
        [B, N1, C] tensor of interpolated features.

    NOTE(review): top_k runs on negated distances, which are now negated back
    before weighting. Left negative, every value would be clamped to 1e-10,
    giving uniform 1/3 weights. Confirm no pretrained model depends on the
    uniform-weight output.
    """
    batch_size = tf.shape(xy1)[0]
    ndataset1 = tf.shape(xy1)[1]

    eps = 1e-6
    # Pairwise squared distances via |a|^2 - 2 a.b + |b|^2. Rounding can push
    # the expansion slightly below zero (and below -eps for large coordinates),
    # which would turn into NaN under sqrt, so clamp at 0 first.
    inner = tf.matmul(xy1, xy2, transpose_b=True)
    norm1 = tf.reduce_sum(xy1 * xy1, axis=-1, keepdims=True)
    norm2 = tf.reduce_sum(xy2 * xy2, axis=-1, keepdims=True)
    sq_dist = tf.maximum(norm1 - 2 * inner + tf.linalg.matrix_transpose(norm2), 0.0)
    dist_mat = tf.sqrt(sq_dist + eps)

    # top_k returns the largest values, so search the negated distances and
    # negate the result back to get the 3 smallest (positive) distances.
    neg_dist, idx = tf.math.top_k(tf.negative(dist_mat), k=3)
    dist = tf.maximum(-neg_dist, 1e-10)

    inv_dist = 1.0 / dist
    weight = inv_dist / tf.reduce_sum(inv_dist, axis=2, keepdims=True)

    idx = tf.reshape(idx, (batch_size, -1))
    nn_points = tf.batch_gather(points2, idx)
    nn_points = tf.reshape(nn_points, (batch_size, ndataset1, 3, points2.get_shape()[-1].value))
    interpolated_points = tf.reduce_sum(weight[..., tf.newaxis] * nn_points, axis=-2)

    return interpolated_points
开发者ID:luigifreda,项目名称:pyslam,代码行数:23,代码来源:augdesc.py

示例4: nucleus_sampling

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def nucleus_sampling(logits, vocab_size, p=0.9,
                     input_ids=None, input_ori_ids=None,
                     **kargs):
    """Build a nucleus (top-p) exclusion mask over the vocabulary.

    Tokens are sorted by softmax probability in descending order. The token
    at sorted position i is marked for exclusion (1.0) when the cumulative
    probability through position i is >= ``p``; position 0 (the most
    probable token) is never excluded. The sorted-order mask is then passed
    to ``reorder`` (defined elsewhere) together with the inverse permutation
    — presumably mapping it back to vocabulary order.

    Args:
      logits: rank-2 or rank-3 logits; rank-3 input is flattened to
        ``(-1, vocab_size)`` before processing.
      vocab_size: vocabulary size.
      p: cumulative-probability threshold.
      input_ids, input_ori_ids: optional; when both are given the mask is
        further adjusted by ``get_extra_mask`` (defined elsewhere).
      **kargs: forwarded to ``get_extra_mask``.

    Returns:
      ``[exclude_mask, input_ori_ids]`` when both optional ids are given,
      otherwise ``[exclude_mask]``. For rank-3 input the mask is reshaped to
      the input's shape.
    """
    shape = bert_utils.get_shape_list(logits, expected_rank=[2, 3])
    is_rank3 = len(shape) == 3
    flat_logits = tf.reshape(logits, (-1, vocab_size)) if is_rank3 else logits

    probs = tf.nn.softmax(flat_logits, axis=-1)
    order = tf.contrib.framework.argsort(probs, direction='DESCENDING')
    running = tf.math.cumsum(tf.batch_gather(probs, order), axis=-1, exclusive=False)

    # Keep everything below the threshold, and always keep the top token so
    # the distribution is never emptied.
    keep = tf.logical_or(running < p, tf.range(vocab_size)[None] < 1)
    mask_sorted = tf.cast(tf.logical_not(keep), tf.float32)

    # argsort of a permutation yields its inverse permutation.
    inverse = tf.cast(tf.contrib.framework.argsort(order), dtype=tf.int32)
    exclude_mask = reorder(mask_sorted, inverse)
    if is_rank3:
        exclude_mask = tf.reshape(exclude_mask, shape)

    if input_ids is None or input_ori_ids is None:
        return [exclude_mask]

    exclude_mask, input_ori_ids = get_extra_mask(
        input_ids, input_ori_ids, exclude_mask, vocab_size, **kargs)
    return [exclude_mask, input_ori_ids]
开发者ID:yyht,项目名称:BERT,代码行数:36,代码来源:nuelus_sampling_utils.py

示例5: _top_k_sample

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. If ignore_ids is given, those tokens get their logits
    lowered by 1e10, both in the returned probabilities and in the
    distribution that is actually sampled from.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param num_samples: number of samples to draw per batch row
    :param k: number of most probable tokens to keep; a Python int, a scalar
              int32 tensor, or a [batch_size] int32 vector
    :return: dict with 'probs' ([batch_size, vocab_size] softmax of the masked
             logits) and 'sample' ([batch_size, num_samples] token ids)

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    # Scope name left as-is (it was copied from _top_p_sample); renaming it
    # would change the names of the ops created here.
    with tf.variable_scope('top_p_sample'):
        _batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        # Mask ignored tokens once and use the masked logits everywhere, so
        # they are excluded from sampling as well as from `probs`.
        masked_logits = logits
        if ignore_ids is not None:
            masked_logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(masked_logits, axis=-1)
        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')

        # Sorted positions >= k are excluded. Reshape a tensor k to a column
        # so both scalar and per-row [batch_size] values broadcast.
        k_expanded = k if isinstance(k, int) else tf.reshape(k, [-1, 1])
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # Sample in the sorted space, then map back to vocabulary ids.
        logits_to_use = tf.batch_gather(masked_logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
开发者ID:rowanz,项目名称:grover,代码行数:35,代码来源:modeling.py

示例6: batch_gather

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def batch_gather(params, indices):
    """Equivalent of the ``tf.batch_gather`` from older TensorFlow versions.

    First tries ``tf.gather`` with ``batch_dims = ndim(indices) - 1`` (newer
    TF). If that fails, falls back to ``tf.batch_gather`` (older TF without
    ``batch_dims``).

    Raises:
        ValueError: if both variants fail; the message contains both errors.
    """
    try:
        return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1)
    except Exception as e1:
        try:
            return tf.batch_gather(params, indices)
        except Exception as e2:
            # Format the exceptions with %s (i.e. str()). Python 3 exceptions
            # have no ``.message`` attribute; using it would raise
            # AttributeError here and hide both real errors.
            raise ValueError('%s\n%s\n' % (e1, e2))
开发者ID:bojone,项目名称:bert4keras,代码行数:12,代码来源:backend.py

示例7: _top_p_sample

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p (nucleus) sampling. If ignore_ids is given, those tokens get
    their logits lowered by 1e10, both in the returned probabilities and in
    the distribution that is actually sampled from.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param num_samples: number of samples to draw per batch row
    :param p: top-p threshold; a Python float, a scalar float tensor, or a
              [batch_size] vector. A Python float above 0.999999 disables the
              top-p cut entirely.
    :return: dict with 'probs' ([batch_size, vocab_size] softmax of the masked
             logits) and 'sample' ([batch_size, num_samples] int32 token ids)

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        _batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        # Mask ignored tokens once and use the masked logits everywhere, so
        # they are excluded from sampling as well as from `probs`.
        masked_logits = logits
        if ignore_ids is not None:
            masked_logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(masked_logits, axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(
                    logits=masked_logits, num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # Exclude sorted positions whose cumulative probability reached p,
        # but never the first one, so the distribution is never emptied.
        # Reshape a tensor p to a column so scalar and [batch_size] values
        # both broadcast against [batch_size, vocab_size].
        p_expanded = p if isinstance(p, float) else tf.reshape(p, [-1, 1])
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # Sample in the sorted space, then map back to vocabulary ids.
        logits_to_use = tf.batch_gather(masked_logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
开发者ID:rowanz,项目名称:grover,代码行数:54,代码来源:modeling.py

示例8: sample_step

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict;
                       applied in both the top-p and the top-k mode
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p threshold, or (when do_topk) the top-k value.
                       In top-k mode it is cast to int32, which truncates, so
                       it must be >= 1 there (e.g. the default 0.95 becomes 0).
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :param do_topk: if True use top-k sampling, otherwise top-p sampling
    :return: new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                   news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # Extract the FINAL SEQ LENGTH: keep only the logits of the last position.
    _, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        # Pass ignore_ids here too, so ignored tokens are masked in top-k mode
        # just as in top-p mode.
        sample_info = _top_k_sample(next_logits, ignore_ids=ignore_ids, num_samples=1,
                                    k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
开发者ID:rowanz,项目名称:grover,代码行数:45,代码来源:modeling.py

示例9: fast_tpu_gather

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def fast_tpu_gather(params, indices, name=None):
  """Gather along axis 1 per batch row, with a TPU-friendly path for floats.

  Non-integer ``params`` are gathered as ``one_hot(indices) @ params`` (batch
  matmul), which is faster than ``gather_nd`` on TPU. Integer ``params``
  (e.g. sequences of ids) use ``tf.batch_gather`` instead, so that no
  precision is lost.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    gather_result: A tensor that has the same rank and dtype as params.
      [batch_size, selected_size, ...]
  """
  with tf.name_scope(name):
    dtype = params.dtype

    def _gather(params, indices):
      """Gather via one_hot(indices) multiplied by params."""
      values = params if dtype == tf.float32 else tf.to_float(params)
      rank = values.shape.ndims
      dims = common_layers.shape_list(values)
      idx_dims = common_layers.shape_list(indices)

      # Batch matmul needs a rank-3 right operand: add a trailing axis to
      # rank-2 params, or fold every trailing axis together for rank > 3.
      if rank == 2:
        values = tf.expand_dims(values, axis=-1)
      elif rank > 3:
        values = tf.reshape(values, [dims[0], dims[1], -1])

      selector = tf.one_hot(indices, dims[1], dtype=values.dtype)
      out = tf.matmul(selector, values)

      # Undo the shape adjustment made above.
      if rank == 2:
        out = tf.squeeze(out, axis=-1)
      elif rank > 3:
        dims[1] = idx_dims[1]
        out = tf.reshape(out, dims)

      return out if dtype == tf.float32 else tf.cast(out, dtype)

    # Integer ids must not go through the float matmul: per the MXU note,
    # bfloat16 represents integers exactly only up to 256, far below typical
    # id values. Encoding/decoding could work around this, but the benefit
    # is small for now.
    if dtype.is_integer:
      gather_result = tf.batch_gather(params, indices)
    else:
      gather_result = _gather(params, indices)

    return gather_result
开发者ID:yyht,项目名称:BERT,代码行数:57,代码来源:beam_search.py

示例10: fast_tpu_gather

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def fast_tpu_gather(params, indices, name=None):
  """Fast gather implementation for models running on TPU.

  This function uses one_hot and batch matmul to do the gather, which is
  faster than gather_nd on TPU. For params with an integer dtype (e.g.
  sequences of ids to gather from), batch_gather is used to keep accuracy.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    gather_result: A tensor that has the same rank as params.
      [batch_size, selected_size, ...]
  """
  with tf.name_scope(name):
    dtype = params.dtype

    def _gather(params, indices):
      """Fast gather using one_hot and batch matmul."""
      if dtype != tf.float32:
        params = tf.to_float(params)
      shape = common_layers.shape_list(params)
      indices_shape = common_layers.shape_list(indices)
      ndims = params.shape.ndims
      # Adjust the shape of params to match one-hot indices, which is the
      # requirement of Batch MatMul.
      if ndims == 2:
        params = tf.expand_dims(params, axis=-1)
      if ndims > 3:
        params = tf.reshape(params, [shape[0], shape[1], -1])
      gather_result = tf.matmul(
          tf.one_hot(indices, shape[1], dtype=params.dtype), params)
      if ndims == 2:
        gather_result = tf.squeeze(gather_result, axis=-1)
      if ndims > 3:
        shape[1] = indices_shape[1]
        gather_result = tf.reshape(gather_result, shape)
      if dtype != tf.float32:
        gather_result = tf.cast(gather_result, dtype)
      return gather_result

    # For any integer dtype (not just int32), use the gather instead of the
    # one_hot matmul to avoid precision loss. The max int value that bfloat16
    # can represent exactly in the MXU is 256, which is smaller than the
    # possible id values; int64 ids would also lose precision in the float32
    # conversion. Encoding/decoding could potentially make it work, but the
    # benefit is small right now.
    if dtype.is_integer:
      gather_result = tf.batch_gather(params, indices)
    else:
      gather_result = _gather(params, indices)

    return gather_result
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:57,代码来源:beam_search.py

示例11: gconv

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def gconv(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):
    """Graph convolution of ``h`` over each node's nearest neighbours.

    Neighbours come from ``top_k(-D)``. Column 0 of the top-k result is used
    as the centre entry (presumably the node itself — confirm). Columns 1-8
    are skipped, and the remaining ``min_nn - 8`` columns are the neighbours
    used. Edge labels are the feature differences neighbour - centre. They
    drive the low-rank edge-conditioned filters (theta1, thetal, theta2),
    and each neighbour's contribution is damped by ``exp(-|label|^2 / 10)``.

    Args:
        h: node features, (B, N, dlm1) per the shape comments below.
        name: prefix for the weight/bias keys in ``self.W`` / ``self.b``.
        in_feat, out_feat: input / output feature sizes.
        stride_th1, stride_th2: roll strides used to build theta1 / theta2;
            ``in_feat // stride_th1`` and ``out_feat // stride_th2`` are the
            numbers of rolled copies.
        compute_graph: if True, ``D`` is recomputed via ``self.compute_graph(h)``.
        return_graph: if True, return ``(x, D)`` instead of ``x``.
        D: distance matrix to reuse when ``compute_graph`` is False.

    Returns:
        ``x`` of shape (B, N, dl), or ``(x, D)`` when ``return_graph``.
    """
    if compute_graph:
        D = self.compute_graph(h)

    n_nb = self.config.min_nn - 8  # neighbours actually used per node (d)

    _, top_idx = tf.nn.top_k(-D, self.config.min_nn + 1)  # (B, N, d+1)
    top_idx2 = tf.reshape(tf.tile(tf.expand_dims(top_idx[:, :, 0], 2), [1, 1, n_nb]), [-1, self.N * n_nb])  # (B, N*d)
    top_idx = tf.reshape(top_idx[:, :, 9:], [-1, self.N * n_nb])  # (B, N*d)

    x_tilde1 = tf.batch_gather(h, top_idx)  # (B, K, dlm1)
    x_tilde2 = tf.batch_gather(h, top_idx2)  # (B, K, dlm1)
    labels = x_tilde1 - x_tilde2  # (B, K, dlm1)
    x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat])  # (B*K, dlm1)
    labels = tf.reshape(labels, [-1, in_feat])  # (B*K, dlm1)
    d_labels = tf.reshape(tf.reduce_sum(labels * labels, 1), [-1, n_nb])  # (B*N, d)

    name_flayer = name + "_flayer0"
    labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer])  # (B*K, F)
    name_flayer = name + "_flayer1"
    labels_exp = tf.expand_dims(labels, 1)  # (B*K, 1, F)
    labels1 = labels_exp + 0.0
    # Integer division: range() requires ints, and `/` yields a float on Python 3.
    for ss in range(1, in_feat // stride_th1):
        labels1 = tf.concat([labels1, self.myroll(labels_exp, shift=(ss + 1) * stride_th1, axis=2)], axis=1)  # (B*K, dlm1/stride, dlm1)
    labels2 = labels_exp + 0.0
    for ss in range(1, out_feat // stride_th2):
        labels2 = tf.concat([labels2, self.myroll(labels_exp, shift=(ss + 1) * stride_th2, axis=2)], axis=1)  # (B*K, dl/stride, dlm1)
    theta1 = tf.matmul(tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer + "_th1"])  # (B*K*dlm1/stride, R*stride)
    theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat]) + self.b[name_flayer + "_th1"]
    theta2 = tf.matmul(tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer + "_th2"])  # (B*K*dl/stride, R*stride)
    theta2 = tf.reshape(theta2, [-1, self.config.rank_theta, out_feat]) + self.b[name_flayer + "_th2"]
    thetal = tf.expand_dims(tf.matmul(labels, self.W[name_flayer + "_thl"]) + self.b[name_flayer + "_thl"], 2)  # (B*K, R, 1)

    x = tf.matmul(theta1, tf.expand_dims(x_tilde1, 2))  # (B*K, R, 1)
    x = tf.multiply(x, thetal)  # (B*K, R, 1)
    x = tf.matmul(theta2, x, transpose_a=True)[:, :, 0]  # (B*K, dl)

    x = tf.reshape(x, [-1, n_nb, out_feat])  # (N, d, dl)
    x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels, 10)), 2))  # (N, d, dl)
    x = tf.reduce_mean(x, 1)  # (N, dl)
    x = tf.reshape(x, [-1, self.N, out_feat])  # (B, N, dl)

    if return_graph:
        return x, D
    return x
开发者ID:diegovalsesia,项目名称:gcdn,代码行数:47,代码来源:net.py

示例12: gconv_conv_inner

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import batch_gather [as 别名]
def gconv_conv_inner(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):
    """Graph convolution where neighbours are searched inside image patches.

    ``h`` is split with ``extract_image_patches`` into search windows of size
    ``config.search_window``, one window per node. Neighbours are then taken
    from ``top_k(-D)`` within each window:
    - Column 0 is used as the centre entry (presumably the node itself —
      confirm).
    - Columns 1-8 are skipped.
    - The remaining ``min_nn - 8`` columns are the neighbours used.
    The edge-conditioned filter is the same as in ``gconv``.

    Args:
        h: feature map, expanded to batch 1 before patch extraction
            (shape comment: (M, dl)).
        name: prefix for the weight/bias keys in ``self.W`` / ``self.b``.
        in_feat, out_feat: input / output feature sizes.
        stride_th1, stride_th2: roll strides; ``in_feat // stride_th1`` and
            ``out_feat // stride_th2`` are the numbers of rolled copies.
        compute_graph: if True, ``D`` is recomputed with
            ``self.gconv_conv_inner2`` via ``tf.map_fn``.
        return_graph: unused; both ``x`` and ``D`` are always returned.
        D: distance matrix (N, W) to reuse when ``compute_graph`` is False.

    Returns:
        ``[x, D]`` with ``x`` of shape (1, N, dl).
    """
    h = tf.expand_dims(h, 0)  # (1,M,dl)
    p = tf.image.extract_image_patches(h, ksizes=[1, self.config.search_window[0], self.config.search_window[1], 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding="VALID")  # (1,X,Y,dlm1*W)
    p = tf.reshape(p, [-1, self.config.search_window[0], self.config.search_window[1], in_feat])
    p = tf.reshape(p, [-1, self.config.searchN, in_feat])  # (N,W,dlm1)

    if compute_graph:
        D = tf.map_fn(lambda feat: self.gconv_conv_inner2(feat), tf.reshape(p, [self.config.search_window[0], self.config.search_window[1], self.config.searchN, in_feat]), parallel_iterations=16, swap_memory=False)  # (B,N/B,W)
        D = tf.reshape(D, [-1, self.config.searchN])  # (N,W)

    n_nb = self.config.min_nn - 8  # neighbours actually used per node (d)

    _, top_idx = tf.nn.top_k(-D, self.config.min_nn + 1)  # (N, d+1)
    top_idx2 = tf.tile(tf.expand_dims(top_idx[:, 0], 1), [1, n_nb])  # (N, d)
    top_idx = top_idx[:, 9:]  # (N, d)

    x_tilde1 = tf.batch_gather(p, top_idx)  # (N, d, dlm1)
    x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat])  # (K, dlm1)
    x_tilde2 = tf.batch_gather(p, top_idx2)  # (N, d, dlm1)
    x_tilde2 = tf.reshape(x_tilde2, [-1, in_feat])  # (K, dlm1)

    labels = x_tilde1 - x_tilde2  # (K, dlm1)
    d_labels = tf.reshape(tf.reduce_sum(labels * labels, 1), [-1, n_nb])  # (N, d)

    name_flayer = name + "_flayer0"
    labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer])
    name_flayer = name + "_flayer1"
    labels_exp = tf.expand_dims(labels, 1)  # (K, 1, F)
    labels1 = labels_exp + 0.0
    # Integer division: range() requires ints, and `/` yields a float on Python 3.
    for ss in range(1, in_feat // stride_th1):
        labels1 = tf.concat([labels1, self.myroll(labels_exp, shift=(ss + 1) * stride_th1, axis=2)], axis=1)  # (K, dlm1/stride, dlm1)
    labels2 = labels_exp + 0.0
    for ss in range(1, out_feat // stride_th2):
        labels2 = tf.concat([labels2, self.myroll(labels_exp, shift=(ss + 1) * stride_th2, axis=2)], axis=1)  # (K, dl/stride, dlm1)
    theta1 = tf.matmul(tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer + "_th1"])  # (K*dlm1/stride, R*stride)
    theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat]) + self.b[name_flayer + "_th1"]
    theta2 = tf.matmul(tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer + "_th2"])  # (K*dl/stride, R*stride)
    theta2 = tf.reshape(theta2, [-1, self.config.rank_theta, out_feat]) + self.b[name_flayer + "_th2"]
    thetal = tf.expand_dims(tf.matmul(labels, self.W[name_flayer + "_thl"]) + self.b[name_flayer + "_thl"], 2)  # (K, R, 1)

    x = tf.matmul(theta1, tf.expand_dims(x_tilde1, 2))  # (K, R, 1)
    x = tf.multiply(x, thetal)  # (K, R, 1)
    x = tf.matmul(theta2, x, transpose_a=True)[:, :, 0]  # (K, dl)

    x = tf.reshape(x, [-1, n_nb, out_feat])  # (N, d, dl)
    x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels, 10)), 2))  # (N, d, dl)
    x = tf.reduce_mean(x, 1)  # (N, dl)

    x = tf.expand_dims(x, 0)  # (1, N, dl)

    return [x, D]
开发者ID:diegovalsesia,项目名称:gcdn,代码行数:54,代码来源:net_conv2.py


注:本文中的tensorflow.batch_gather方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。