This article collects typical usage examples of the tensorflow.batch_gather method in Python. If you have been wondering what tensorflow.batch_gather does, how to use it, or want to see it in context, the curated code examples below should help. You can also explore further usage examples of the tensorflow module it belongs to.
The following presents 12 code examples of tensorflow.batch_gather, sorted by popularity by default. You can upvote the examples you find useful; your feedback helps the system recommend better Python code samples.
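Before diving into the examples, here is a minimal sketch of what tf.batch_gather computes (the tensor values are illustrative): it indexes into the last axis of params while treating all leading axes as batch dimensions. The op was removed in TF 2.x, where tf.gather with batch_dims is the documented replacement.

import tensorflow as tf

params = tf.constant([[10, 11, 12],
                      [20, 21, 22]])   # shape (2, 3)
indices = tf.constant([[2, 0],
                       [1, 1]])        # shape (2, 2)

# TF 1.x: result = tf.batch_gather(params, indices)
# TF 2.x equivalent: gather within each batch row along the last axis
result = tf.gather(params, indices, batch_dims=1)
print(result)  # [[12, 10], [21, 21]]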
Example 1: _call
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def _call(self, inputs):
eps = 0.001
ids, num_samples, features, batch_size = inputs
adj_lists = tf.gather(self.adj_info, ids)
node_features = tf.gather(features, ids)
feature_size = tf.shape(features)[-1]
node_feature_repeat = tf.tile(node_features, [1,self.num_neighs])
node_feature_repeat = tf.reshape(node_feature_repeat, [batch_size, self.num_neighs, feature_size])
neighbor_feature = tf.gather(features, adj_lists)
distance = tf.sqrt(tf.reduce_sum(tf.square(node_feature_repeat - neighbor_feature), -1))
prob = tf.exp(-distance)
prob_sum = tf.reduce_sum(prob, -1, keepdims=True)
prob_sum = tf.tile(prob_sum, [1,self.num_neighs])
prob = tf.divide(prob, prob_sum)
prob = tf.where(prob > eps, prob, 0 * prob)  # use eps to filter out small probabilities; comment this line out to disable the filtering
samples_idx = tf.random.categorical(tf.math.log(prob), num_samples)
selected = tf.batch_gather(adj_lists, samples_idx)
return selected
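A self-contained toy version of the sampling pattern in this example, with made-up shapes and values (not taken from the original class): the probabilities become log-probabilities for tf.random.categorical, which returns positions, and a batched gather maps the positions back to neighbor ids.

import tensorflow as tf

adj_lists = tf.constant([[5, 7, 9], [2, 4, 6]])          # (batch, num_neighs) neighbor ids
prob = tf.constant([[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]])   # (batch, num_neighs), rows sum to 1

samples_idx = tf.random.categorical(tf.math.log(prob), num_samples=2)  # positions in [0, num_neighs)
selected = tf.gather(adj_lists, samples_idx, batch_dims=1)             # sampled neighbor ids, (batch, 2)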
Example 2: gather
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def gather(self, data, pl_idx, pl_mask, max_len, name=None):
"""
Lookup equivalent for tensors with dim > 2 (can be simplified using tf.batch_gather; see the sketch after this example)
Parameters
----------
data: Tensor in which lookup has to be performed
pl_idx: The indices to be taken
pl_mask: For handling padding in pl_idx
max_len: Maximum length of indices
Returns
-------
et_vecs * mask_vec: Extracted vectors at given indices
"""
idx1 = tf.range(self.p.batch_size, dtype=tf.int32)
idx1 = tf.reshape(idx1, [-1, 1])
idx1_ = tf.reshape(tf.tile(idx1, [1, max_len]) , [-1, 1])
idx_reshape = tf.reshape(pl_idx, [-1, 1])
indices = tf.concat((idx1_, idx_reshape), axis=1)
et_vecs = tf.gather_nd(data, indices)
et_vecs = tf.reshape(et_vecs, [self.p.batch_size, self.max_et, -1])
mask_vec = tf.expand_dims(pl_mask, axis=2)
return et_vecs * mask_vec
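The simplification the docstring hints at, as a sketch with standalone tensors of assumed shapes (not the class's placeholders): tf.batch_gather, or tf.gather with batch_dims=1 on TF 2.x, replaces the index construction plus gather_nd in a single call; the mask multiplication stays the same.

import tensorflow as tf

data = tf.random.normal([4, 10, 32])                             # (batch_size, seq_len, feat)
pl_idx = tf.constant([[0, 3], [1, 1], [2, 5], [9, 0]])           # (batch_size, max_len)
pl_mask = tf.constant([[1., 1.], [1., 0.], [1., 1.], [1., 1.]])  # (batch_size, max_len)

et_vecs = tf.gather(data, pl_idx, batch_dims=1)                  # (batch_size, max_len, feat)
masked = et_vecs * tf.expand_dims(pl_mask, axis=2)               # zero out padded positions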
Example 3: _interpolate
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def _interpolate(self, xy1, xy2, points2):
batch_size = tf.shape(xy1)[0]
ndataset1 = tf.shape(xy1)[1]
eps = 1e-6
dist_mat = tf.matmul(xy1, xy2, transpose_b=True)
norm1 = tf.reduce_sum(xy1 * xy1, axis=-1, keepdims=True)
norm2 = tf.reduce_sum(xy2 * xy2, axis=-1, keepdims=True)
dist_mat = tf.sqrt(norm1 - 2 * dist_mat + tf.linalg.matrix_transpose(norm2) + eps)
dist, idx = tf.math.top_k(tf.negative(dist_mat), k=3)
dist = tf.maximum(dist, 1e-10)
norm = tf.reduce_sum((1.0 / dist), axis=2, keepdims=True)
norm = tf.tile(norm, [1, 1, 3])
weight = (1.0 / dist) / norm
idx = tf.reshape(idx, (batch_size, -1))
nn_points = tf.batch_gather(points2, idx)
nn_points = tf.reshape(nn_points, (batch_size, ndataset1, 3, points2.get_shape()[-1].value))
interpolated_points = tf.reduce_sum(weight[..., tf.newaxis] * nn_points, axis=-2)
return interpolated_points
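The distance matrix above relies on the identity ||a - b||^2 = ||a||^2 - 2*a.b + ||b||^2 rather than materializing all pairwise differences. A quick standalone check of that equivalence on toy point clouds (shapes are illustrative):

import tensorflow as tf

xy1 = tf.random.normal([2, 5, 3])    # (batch, n1, 3)
xy2 = tf.random.normal([2, 8, 3])    # (batch, n2, 3)

norm1 = tf.reduce_sum(xy1 * xy1, axis=-1, keepdims=True)         # (batch, n1, 1)
norm2 = tf.reduce_sum(xy2 * xy2, axis=-1, keepdims=True)         # (batch, n2, 1)
inner = tf.matmul(xy1, xy2, transpose_b=True)                    # (batch, n1, n2)
dist_sq = norm1 - 2 * inner + tf.linalg.matrix_transpose(norm2)  # squared pairwise distances

# Reference via explicit broadcasting; should agree up to float error
ref = tf.reduce_sum((xy1[:, :, None, :] - xy2[:, None, :, :]) ** 2, axis=-1)
print(tf.reduce_max(tf.abs(dist_sq - ref)))  # ~1e-6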
Example 4: nucleus_sampling
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def nucleus_sampling(logits, vocab_size, p=0.9,
input_ids=None, input_ori_ids=None,
**kargs):
input_shape_list = bert_utils.get_shape_list(logits, expected_rank=[2,3])
if len(input_shape_list) == 3:
logits = tf.reshape(logits, (-1, vocab_size))
probs = tf.nn.softmax(logits, axis=-1)
# [batch_size, seq, vocab_perm]
# indices = tf.argsort(probs, direction='DESCENDING')
indices = tf.contrib.framework.argsort(probs, direction='DESCENDING')
cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)
# find the top pth index to cut off. careful we don't want to cutoff everything!
# result will be [batch_size, seq, vocab_perm]
exclude_mask = tf.logical_not(
tf.logical_or(cumulative_probabilities < p, tf.range(vocab_size)[None] < 1))
exclude_mask = tf.cast(exclude_mask, tf.float32)
indices_v1 = tf.contrib.framework.argsort(indices)
exclude_mask = reorder(exclude_mask, tf.cast(indices_v1, dtype=tf.int32))
if len(input_shape_list) == 3:
exclude_mask = tf.reshape(exclude_mask, input_shape_list)
# logits = tf.reshape(logits, input_shape_list)
if input_ids is not None and input_ori_ids is not None:
exclude_mask, input_ori_ids = get_extra_mask(
input_ids, input_ori_ids,
exclude_mask, vocab_size,
**kargs)
return [exclude_mask, input_ori_ids]
else:
return [exclude_mask]
Example 5: _top_k_sample
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
"""
Does top-k sampling. If ignore_ids is set, those logits are masked out and never predicted.
:param logits: [batch_size, vocab_size] tensor
:param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
like padding maybe
:param num_samples: number of samples to draw
:param k: top-k threshold to use, either an int or a [batch_size] vector
:return: [batch_size, num_samples] samples
# TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
"""
with tf.variable_scope('top_k_sample'):
batch_size, vocab_size = get_shape_list(logits, expected_rank=2)
probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
axis=-1)
# [batch_size, vocab_perm]
indices = tf.argsort(probs, direction='DESCENDING')
# find the top pth index to cut off. careful we don't want to cutoff everything!
# result will be [batch_size, vocab_perm]
k_expanded = k if isinstance(k, int) else k[:, None]
exclude_mask = tf.range(vocab_size)[None] >= k_expanded
# OPTION A - sample in the sorted space, then unsort.
logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
sample = tf.batch_gather(indices, sample_perm)
return {
'probs': probs,
'sample': sample,
}
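For contrast, a top-k filter can skip the full argsort entirely by using tf.math.top_k for the cutoff value. This is a sketch, not the implementation above; note that ties at the cutoff keep slightly more than k tokens.

import tensorflow as tf

def top_k_sample_simple(logits, k=10, num_samples=1):
    # The k-th largest logit per row is the cutoff; everything below it is masked out.
    values, _ = tf.math.top_k(logits, k=k)        # (batch, k), sorted descending
    cutoff = values[:, -1, None]                  # (batch, 1)
    masked = tf.where(logits < cutoff, tf.fill(tf.shape(logits), -1e10), logits)
    return tf.random.categorical(logits=masked, num_samples=num_samples)

samples = top_k_sample_simple(tf.random.normal([4, 100]), k=10)  # (4, 1) token ids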
Example 6: batch_gather
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def batch_gather(params, indices):
"""同tf旧版本的batch_gather
"""
try:
return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1)
except Exception as e1:
try:
return tf.batch_gather(params, indices)
except Exception as e2:
raise ValueError('%s\n%s\n' % (e1, e2))  # exceptions have no .message attribute in Python 3
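Usage is then uniform across TF versions; on newer releases the first branch (tf.gather with batch_dims) runs, while older 1.x releases fall through to tf.batch_gather:

params = tf.constant([[1, 2, 3], [4, 5, 6]])
indices = tf.constant([[0, 2], [1, 0]])
print(batch_gather(params, indices))  # [[1, 3], [5, 4]]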
Example 7: _top_p_sample
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
"""
Does top-p (nucleus) sampling. If ignore_ids is set, those logits are masked out and never predicted.
:param logits: [batch_size, vocab_size] tensor
:param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
like padding maybe
:param num_samples: number of samples to draw
:param p: top-p threshold to use, either a float or a [batch_size] vector
:return: [batch_size, num_samples] samples
# TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
"""
with tf.variable_scope('top_p_sample'):
batch_size, vocab_size = get_shape_list(logits, expected_rank=2)
probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
axis=-1)
if isinstance(p, float) and p > 0.999999:
# Don't do top-p sampling in this case
print("Top-p sampling DISABLED", flush=True)
return {
'probs': probs,
'sample': tf.random.categorical(
logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
num_samples=num_samples, dtype=tf.int32),
}
# [batch_size, vocab_perm]
indices = tf.argsort(probs, direction='DESCENDING')
cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)
# find the top pth index to cut off. careful we don't want to cutoff everything!
# result will be [batch_size, vocab_perm]
p_expanded = p if isinstance(p, float) else p[:, None]
exclude_mask = tf.logical_not(
tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))
# OPTION A - sample in the sorted space, then unsort.
logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
sample = tf.batch_gather(indices, sample_perm)
# OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
# unperm_indices = tf.argsort(indices, direction='ASCENDING')
# include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
# logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
# sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)
return {
'probs': probs,
'sample': sample,
}
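The same top-p mechanics can be written without tf.contrib as a compact TF 2.x sketch (shapes are illustrative; get_shape_list and the OPTION B path are omitted): sort, accumulate probability mass, mask, sample in the sorted space, then map back through the sort indices.

import tensorflow as tf

def top_p_sample_simple(logits, p=0.9, num_samples=1):
    probs = tf.nn.softmax(logits, axis=-1)
    indices = tf.argsort(probs, direction='DESCENDING')
    sorted_probs = tf.gather(probs, indices, batch_dims=1)
    cumulative = tf.math.cumsum(sorted_probs, axis=-1, exclusive=False)
    # Keep tokens inside the top-p mass; always keep the single most likely token.
    keep = tf.logical_or(cumulative < p, tf.range(tf.shape(logits)[-1])[None] < 1)
    sorted_logits = tf.gather(logits, indices, batch_dims=1)
    masked = tf.where(keep, sorted_logits, tf.fill(tf.shape(sorted_logits), -1e10))
    sample_perm = tf.random.categorical(logits=masked, num_samples=num_samples)
    return tf.gather(indices, sample_perm, batch_dims=1)  # unsort back to vocab ids

samples = top_p_sample_simple(tf.random.normal([4, 100]), p=0.9)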
Example 8: sample_step
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
"""
Helper function that samples from Grover for a single step.
:param tokens: [batch_size, n_ctx_b] tokens that we will predict from
:param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
:param news_config: config for the GroverModel
:param batch_size: batch size to use
:param p_for_topp: top-p or top-k threshold
:param cache: [batch_size, news_config.num_hidden_layers, 2,
news_config.num_attention_heads, n_ctx_a,
news_config.hidden_size // news_config.num_attention_heads] OR, None
:return: new_tokens, size [batch_size]
new_probs, also size [batch_size]
new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
"""
model = GroverModel(
config=news_config,
is_training=False,
input_ids=tokens,
reuse=tf.AUTO_REUSE,
scope='newslm',
chop_off_last_token=False,
do_cache=True,
cache=cache,
)
# Extract the FINAL SEQ LENGTH
batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]
if do_topk:
sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
else:
sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)
new_tokens = tf.squeeze(sample_info['sample'], 1)
new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
return {
'new_tokens': new_tokens,
'new_probs': new_probs,
'new_cache': model.new_kvs,
}
Example 9: fast_tpu_gather
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def fast_tpu_gather(params, indices, name=None):
"""Fast gather implementation for models running on TPU.
This function uses one_hot and batch matmul to do the gather, which is faster
than gather_nd on TPU. For params with an integer dtype (sequences to
gather from), batch_gather is used instead to preserve accuracy.
Args:
params: A tensor from which to gather values.
[batch_size, original_size, ...]
indices: A tensor used as the index to gather values.
[batch_size, selected_size].
name: A string, name of the operation (optional).
Returns:
gather_result: A tensor that has the same rank as params.
[batch_size, selected_size, ...]
"""
with tf.name_scope(name):
dtype = params.dtype
def _gather(params, indices):
"""Fast gather using one_hot and batch matmul."""
if dtype != tf.float32:
params = tf.to_float(params)
shape = common_layers.shape_list(params)
indices_shape = common_layers.shape_list(indices)
ndims = params.shape.ndims
# Adjust the shape of params to match one-hot indices, which is the
# requirement of Batch MatMul.
if ndims == 2:
params = tf.expand_dims(params, axis=-1)
if ndims > 3:
params = tf.reshape(params, [shape[0], shape[1], -1])
gather_result = tf.matmul(
tf.one_hot(indices, shape[1], dtype=params.dtype), params)
if ndims == 2:
gather_result = tf.squeeze(gather_result, axis=-1)
if ndims > 3:
shape[1] = indices_shape[1]
gather_result = tf.reshape(gather_result, shape)
if dtype != tf.float32:
gather_result = tf.cast(gather_result, dtype)
return gather_result
# If the dtype is an integer type, use gather instead of the one_hot matmul to
# avoid precision loss. The max int value that can be represented by bfloat16
# in the MXU is 256, which is smaller than the possible id values.
# Encoding/decoding could potentially be used to make it work, but the benefit
# is small right now.
if dtype.is_integer:
gather_result = tf.batch_gather(params, indices)
else:
gather_result = _gather(params, indices)
return gather_result
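A quick numeric check on toy shapes that the one-hot matmul path agrees with a plain batched gather, which is the equivalence this function relies on:

import tensorflow as tf

params = tf.random.normal([2, 6, 4])             # (batch, original_size, feat)
indices = tf.constant([[0, 5, 2], [3, 3, 1]])    # (batch, selected_size)

via_matmul = tf.matmul(tf.one_hot(indices, 6, dtype=params.dtype), params)
via_gather = tf.gather(params, indices, batch_dims=1)
print(tf.reduce_max(tf.abs(via_matmul - via_gather)))  # 0.0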
Example 10: fast_tpu_gather
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def fast_tpu_gather(params, indices, name=None):
"""Fast gather implementation for models running on TPU.
This function uses one_hot and batch matmul to do the gather, which is faster
than gather_nd on TPU. For params that have a dtype of int32 (sequences to
gather from), batch_gather is used instead to preserve accuracy.
Args:
params: A tensor from which to gather values.
[batch_size, original_size, ...]
indices: A tensor used as the index to gather values.
[batch_size, selected_size].
name: A string, name of the operation (optional).
Returns:
gather_result: A tensor that has the same rank as params.
[batch_size, selected_size, ...]
"""
with tf.name_scope(name):
dtype = params.dtype
def _gather(params, indices):
"""Fast gather using one_hot and batch matmul."""
if dtype != tf.float32:
params = tf.to_float(params)
shape = common_layers.shape_list(params)
indices_shape = common_layers.shape_list(indices)
ndims = params.shape.ndims
# Adjust the shape of params to match one-hot indices, which is the
# requirement of Batch MatMul.
if ndims == 2:
params = tf.expand_dims(params, axis=-1)
if ndims > 3:
params = tf.reshape(params, [shape[0], shape[1], -1])
gather_result = tf.matmul(
tf.one_hot(indices, shape[1], dtype=params.dtype), params)
if ndims == 2:
gather_result = tf.squeeze(gather_result, axis=-1)
if ndims > 3:
shape[1] = indices_shape[1]
gather_result = tf.reshape(gather_result, shape)
if dtype != tf.float32:
gather_result = tf.cast(gather_result, dtype)
return gather_result
# If the dtype is int32, use gather instead of the one_hot matmul to avoid
# precision loss. The max int value that can be represented by bfloat16 in the
# MXU is 256, which is smaller than the possible id values. Encoding/decoding
# could potentially be used to make it work, but the benefit is small right now.
if dtype == tf.int32:
gather_result = tf.batch_gather(params, indices)
else:
gather_result = _gather(params, indices)
return gather_result
Example 11: gconv
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def gconv(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):
if compute_graph:
D = self.compute_graph(h)
_, top_idx = tf.nn.top_k(-D, self.config.min_nn+1) # (B, N, d+1)
top_idx2 = tf.reshape(tf.tile(tf.expand_dims(top_idx[:,:,0],2), [1, 1, self.config.min_nn-8]), [-1, self.N*(self.config.min_nn-8)]) # (B, N*d)
top_idx = tf.reshape(top_idx[:,:,9:],[-1, self.N*(self.config.min_nn-8)]) # (B, N*d)
x_tilde1 = tf.batch_gather(h, top_idx) # (B, K, dlm1)
x_tilde2 = tf.batch_gather(h, top_idx2) # (B, K, dlm1)
labels = x_tilde1 - x_tilde2 # (B, K, dlm1)
x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat]) # (B*K, dlm1)
labels = tf.reshape(labels, [-1, in_feat]) # (B*K, dlm1)
d_labels = tf.reshape( tf.reduce_sum(labels*labels, 1), [-1, self.config.min_nn-8]) # (B*N, d)
name_flayer = name + "_flayer0"
labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer]) # (B*K, F)
name_flayer = name + "_flayer1"
labels_exp = tf.expand_dims(labels, 1) # (B*K, 1, F)
labels1 = labels_exp+0.0
for ss in range(1, in_feat // stride_th1):  # integer division: range() needs an int under Python 3
labels1 = tf.concat( [labels1, self.myroll(labels_exp, shift=(ss+1)*stride_th1, axis=2)], axis=1 ) # (B*K, dlm1/stride, dlm1)
labels2 = labels_exp+0.0
for ss in range(1, out_feat // stride_th2):  # integer division, as above
labels2 = tf.concat( [labels2, self.myroll(labels_exp, shift=(ss+1)*stride_th2, axis=2)], axis=1 ) # (B*K, dl/stride, dlm1)
theta1 = tf.matmul( tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer+"_th1"] ) # (B*K*dlm1/stride, R*stride)
theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat] ) + self.b[name_flayer+"_th1"]
theta2 = tf.matmul( tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer+"_th2"] ) # (B*K*dl/stride, R*stride)
theta2 = tf.reshape(theta2, [-1, self.config.rank_theta, out_feat] ) + self.b[name_flayer+"_th2"]
thetal = tf.expand_dims( tf.matmul(labels, self.W[name_flayer+"_thl"]) + self.b[name_flayer+"_thl"], 2 ) # (B*K, R, 1)
x = tf.matmul(theta1, tf.expand_dims(x_tilde1,2)) # (B*K, R, 1)
x = tf.multiply(x, thetal) # (B*K, R, 1)
x = tf.matmul(theta2, x, transpose_a=True)[:,:,0] # (B*K, dl)
x = tf.reshape(x, [-1, self.config.min_nn-8, out_feat]) # (N, d, dl)
x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels,10)),2)) # (N, d, dl)
x = tf.reduce_mean(x, 1) # (N, dl)
x = tf.reshape(x,[-1, self.N, out_feat]) # (B, N, dl)
if return_graph:
return x, D
else:
return x
Example 12: gconv_conv_inner
# Required imports: import tensorflow [as alias]
# Or: from tensorflow import batch_gather [as alias]
def gconv_conv_inner(self, h, name, in_feat, out_feat, stride_th1, stride_th2, compute_graph=True, return_graph=False, D=[]):
h = tf.expand_dims(h, 0) # (1,M,dl)
p = tf.image.extract_image_patches(h, ksizes=[1, self.config.search_window[0], self.config.search_window[1], 1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID") # (1,X,Y,dlm1*W)
p = tf.reshape(p,[-1, self.config.search_window[0], self.config.search_window[1], in_feat])
p = tf.reshape(p,[-1, self.config.searchN, in_feat]) # (N,W,dlm1)
if compute_graph:
D = tf.map_fn(lambda feat: self.gconv_conv_inner2(feat), tf.reshape(p,[self.config.search_window[0],self.config.search_window[1],self.config.searchN, in_feat]), parallel_iterations=16, swap_memory=False) # (B,N/B,W)
D = tf.reshape(D,[-1, self.config.searchN]) # (N,W)
_, top_idx = tf.nn.top_k(-D, self.config.min_nn+1) # (N, d+1)
#top_idx2 = tf.reshape(tf.tile(tf.expand_dims(top_idx[:,0],1), [1, self.config.min_nn[i]]), [-1])
top_idx2 = tf.tile(tf.expand_dims(top_idx[:,0],1), [1, self.config.min_nn-8]) # (N, d)
#top_idx = tf.reshape(top_idx[:,1:],[-1]) # (N*d,)
top_idx = top_idx[:,9:] # (N, d)
x_tilde1 = tf.batch_gather(p, top_idx) # (N, d, dlm1)
x_tilde1 = tf.reshape(x_tilde1, [-1, in_feat]) # (K, dlm1)
x_tilde2 = tf.batch_gather(p, top_idx2) # (N, d, dlm1)
x_tilde2 = tf.reshape(x_tilde2, [-1, in_feat]) # (K, dlm1)
labels = x_tilde1 - x_tilde2 # (K, dlm1)
d_labels = tf.reshape( tf.reduce_sum(labels*labels, 1), [-1, self.config.min_nn-8]) # (N, d)
name_flayer = name + "_flayer0"
labels = tf.nn.leaky_relu(tf.matmul(labels, self.W[name_flayer]) + self.b[name_flayer])
name_flayer = name + "_flayer1"
labels_exp = tf.expand_dims(labels, 1) # (B*K, 1, F)
labels1 = labels_exp+0.0
for ss in range(1, in_feat // stride_th1):  # integer division: range() needs an int under Python 3
labels1 = tf.concat( [labels1, self.myroll(labels_exp, shift=(ss+1)*stride_th1, axis=2)], axis=1 ) # (B*K, dlm1/stride, dlm1)
labels2 = labels_exp+0.0
for ss in range(1, out_feat // stride_th2):  # integer division, as above
labels2 = tf.concat( [labels2, self.myroll(labels_exp, shift=(ss+1)*stride_th2, axis=2)], axis=1 ) # (B*K, dl/stride, dlm1)
theta1 = tf.matmul( tf.reshape(labels1, [-1, in_feat]), self.W[name_flayer+"_th1"] ) # (B*K*dlm1/stride, R*stride)
theta1 = tf.reshape(theta1, [-1, self.config.rank_theta, in_feat] ) + self.b[name_flayer+"_th1"]
theta2 = tf.matmul( tf.reshape(labels2, [-1, in_feat]), self.W[name_flayer+"_th2"] ) # (B*K*dl/stride, R*stride)
theta2 = tf.reshape(theta2, [-1, self.config.rank_theta, out_feat] ) + self.b[name_flayer+"_th2"]
thetal = tf.expand_dims( tf.matmul(labels, self.W[name_flayer+"_thl"]) + self.b[name_flayer+"_thl"], 2 ) # (B*K, R, 1)
x = tf.matmul(theta1, tf.expand_dims(x_tilde1,2)) # (K, R, 1)
x = tf.multiply(x, thetal) # (K, R, 1)
x = tf.matmul(theta2, x, transpose_a=True)[:,:,0] # (K, dl)
x = tf.reshape(x, [-1, self.config.min_nn-8, out_feat]) # (N, d, dl)
x = tf.multiply(x, tf.expand_dims(tf.exp(-tf.div(d_labels,10)),2)) # (N, d, dl)
x = tf.reduce_mean(x, 1) # (N, dl)
x = tf.expand_dims(x,0) # (1, N, dl)
return [x, D]