本文整理汇总了 Python 中 tensor2tensor.layers.common_attention.local_attention_2d 方法的典型用法代码示例。如果您正苦于以下问题：Python common_attention.local_attention_2d 方法的具体用法？Python common_attention.local_attention_2d 怎么用？Python common_attention.local_attention_2d 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensor2tensor.layers.common_attention 的用法示例。
在下文中一共展示了 common_attention.local_attention_2d 方法的 5 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: testLocalUnmaskedAttention2D
# 需要导入模块: from tensor2tensor.layers import common_attention [as 别名]
# 或者: from tensor2tensor.layers.common_attention import local_attention_2d [as 别名]
def testLocalUnmaskedAttention2D(self, batch, heads, length,
                                 depth_k, depth_v, query_shape):
    """Check that 2-D local attention returns a tensor shaped like the values.

    Random q/k/v tensors of shape
    ``[batch, heads, length, length, depth]`` are fed through
    ``common_attention.local_attention_2d`` with a ``(3, 3)`` memory flange,
    and the evaluated result is compared against
    ``(batch, heads, length, length, depth_v)``.

    Args:
        batch: Batch size, or ``None`` to draw one at graph-run time with
            ``tf.random_uniform(minval=0, maxval=5)``.
        heads: Size of the heads dimension.
        length: Size used for both spatial dimensions.
        depth_k: Last-dimension size of the queries and keys.
        depth_v: Last-dimension size of the values.
        query_shape: Forwarded unchanged as ``query_shape``.
    """
    if batch is None:
        batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)

    # All three inputs share every dimension except the last one.
    leading_dims = [batch, heads, length, length]
    queries = tf.random_normal(leading_dims + [depth_k])
    keys = tf.random_normal(leading_dims + [depth_k])
    values = tf.random_normal(leading_dims + [depth_v])

    attended = common_attention.local_attention_2d(
        queries, keys, values,
        query_shape=query_shape,
        memory_flange=(3, 3))

    if not isinstance(batch, tf.Tensor):
        result = self.evaluate(attended)
    else:
        # A symbolic batch size must be evaluated together with the output
        # so the expected shape uses the value that was actually drawn.
        batch, result = self.evaluate([batch, attended])

    self.assertEqual(result.shape, (batch, heads, length, length, depth_v))
示例2: load_model
# 需要导入模块: from tensor2tensor.layers import common_attention [as 别名]
# 或者: from tensor2tensor.layers.common_attention import local_attention_2d [as 别名]
def load_model(model_path):
    """Rebuild a model from ``arch.yaml`` and load ``weights.h5`` into it.

    Both files are read from the directory ``model_path``. If any top-level
    key of the weights file contains the substring ``"model"``, the weights
    are loaded through a 2-GPU ``multi_gpu_model`` wrapper and the wrapper's
    second-to-last layer is returned instead of the wrapper itself.

    Args:
        model_path: Directory containing ``arch.yaml`` and ``weights.h5``.

    Returns:
        The model with its weights loaded.
    """
    custom_layers = {
        "multihead_attention": multihead_attention,
        "Conv2D": L.Conv2D,
        "split_heads_2d": split_heads_2d,
        "local_attention_2d": local_attention_2d,
        "combine_heads_2d": combine_heads_2d,
    }

    # NOTE(review): model_from_yaml deserializes YAML from disk; only call
    # this with a model_path whose contents are trusted.
    # The context manager closes the architecture file deterministically
    # instead of leaving the handle to the garbage collector.
    with open(os.path.join(model_path, "arch.yaml")) as arch_file:
        model = model_from_yaml(arch_file.read(), custom_objects=custom_layers)

    full_path = os.path.join(model_path, "weights.h5")
    with h5py.File(full_path, "r") as w:
        # Presumably a "model*" top-level group means the weights were saved
        # from a multi-GPU wrapper -- TODO confirm against the training code.
        is_para = any("model" in k for k in w.keys())

    if is_para:
        para_model = multi_gpu_model(model, gpus=2)
        para_model.load_weights(full_path)
        # Unwrap: the second-to-last layer of the wrapper is taken as the
        # underlying model.
        model = para_model.layers[-2]
    else:
        model.load_weights(full_path)

    print(f"Model {model_path} loaded")
    return model
示例3: testLocalUnmaskedAttention2D
# 需要导入模块: from tensor2tensor.layers import common_attention [as 别名]
# 或者: from tensor2tensor.layers.common_attention import local_attention_2d [as 别名]
def testLocalUnmaskedAttention2D(self):
    """Check the output shape of 2-D local attention with a (4, 4) query block.

    Random arrays of shape ``(5, 4, 25, 25, 16)`` are used: one as the
    queries and another as both keys and values. The evaluated result must
    have that same shape.
    """
    expected_shape = (5, 4, 25, 25, 16)
    query_data = np.random.rand(*expected_shape)
    memory_data = np.random.rand(*expected_shape)

    with self.test_session() as sess:
        queries = tf.constant(query_data, dtype=tf.float32)
        keys = tf.constant(memory_data, dtype=tf.float32)
        values = tf.constant(memory_data, dtype=tf.float32)
        attended = common_attention.local_attention_2d(
            queries, keys, values,
            query_shape=(4, 4),
            memory_flange=(3, 3))
        sess.run(tf.global_variables_initializer())
        result = sess.run(attended)

    self.assertEqual(result.shape, expected_shape)
示例4: testLocalUnmaskedAttention2DMatchingBlockLength
# 需要导入模块: from tensor2tensor.layers import common_attention [as 别名]
# 或者: from tensor2tensor.layers.common_attention import local_attention_2d [as 别名]
def testLocalUnmaskedAttention2DMatchingBlockLength(self):
    """Check the output shape of 2-D local attention with a (5, 5) query block.

    The spatial size used here (25) is an exact multiple of the query block
    size (5). Random arrays of shape ``(5, 4, 25, 25, 16)`` are used: one as
    the queries and another as both keys and values. The evaluated result
    must have that same shape.
    """
    expected_shape = (5, 4, 25, 25, 16)
    query_data = np.random.rand(*expected_shape)
    memory_data = np.random.rand(*expected_shape)

    with self.test_session() as sess:
        queries = tf.constant(query_data, dtype=tf.float32)
        keys = tf.constant(memory_data, dtype=tf.float32)
        values = tf.constant(memory_data, dtype=tf.float32)
        attended = common_attention.local_attention_2d(
            queries, keys, values,
            query_shape=(5, 5),
            memory_flange=(3, 3))
        sess.run(tf.global_variables_initializer())
        result = sess.run(attended)

    self.assertEqual(result.shape, expected_shape)
示例5: multihead_attention
# 需要导入模块: from tensor2tensor.layers import common_attention [as 别名]
# 或者: from tensor2tensor.layers.common_attention import local_attention_2d [as 别名]
def multihead_attention(x, out_channel=64, d_model=32, n_heads=8, query_shape=(128, 24), memory_flange=(8, 8)):
    """Apply multi-head 2-D local attention to ``x`` and project the result.

    Queries, keys and values are each produced from ``x`` by a 3x3
    same-padded convolution with ``d_model`` filters, split into ``n_heads``
    heads, attended with ``local_attention_2d``, recombined, and passed
    through a final bias-free 3x3 convolution with ``out_channel`` filters.

    Args:
        x: Input tensor accepted by ``Conv2D`` (presumably
            ``(batch, height, width, channels)`` -- TODO confirm).
        out_channel: Number of filters of the output convolution.
        d_model: Number of filters of the q/k/v convolutions.
        n_heads: Number of attention heads; the per-head depth used for
            scaling is ``d_model // n_heads``.
        query_shape: Forwarded to ``local_attention_2d``.
        memory_flange: Forwarded to ``local_attention_2d``.

    Returns:
        The output of the final convolution.
    """
    # NOTE(review): the q/k/v layers use fixed names, so calling this
    # function twice in one model may produce duplicate layer names -- confirm.
    q = Conv2D(d_model, (3, 3), strides=(1, 1), padding="same", name="gen_q_conv")(x)
    k = Conv2D(d_model, (3, 3), strides=(1, 1), padding="same", name="gen_k_conv")(x)
    v = Conv2D(d_model, (3, 3), strides=(1, 1), padding="same", name="gen_v_conv")(x)

    q = split_heads_2d(q, n_heads)
    k = split_heads_2d(k, n_heads)
    v = split_heads_2d(v, n_heads)

    # Scale the queries by 1/sqrt(per-head depth) before attention.
    k_depth_per_head = d_model // n_heads
    q *= k_depth_per_head**-0.5

    # A hand-written expansion of this step (pad to a multiple of the query
    # shape, gather query and memory blocks, dot attention, scatter back,
    # strip padding) used to sit here as a disabled string literal; it was
    # dead code and has been dropped in favour of the library call.
    output = local_attention_2d(q, k, v, query_shape=query_shape, memory_flange=memory_flange)

    output = combine_heads_2d(output)
    output = Conv2D(out_channel, (3, 3), strides=(1, 1), padding="same", use_bias=False)(output)
    return output