

Python common_attention.local_attention_2d method code examples

This article collects typical usage examples of the Python method tensor2tensor.layers.common_attention.local_attention_2d. If you are wondering how to call common_attention.local_attention_2d, or what it looks like in real code, the curated examples below should help. You can also explore further usage examples from the containing module, tensor2tensor.layers.common_attention.


The following presents 5 code examples of the common_attention.local_attention_2d method, sorted by popularity by default.
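
Before diving into the examples, a minimal usage sketch may help. This sketch is not taken from any of the projects below; it assumes TensorFlow 1.x with tensor2tensor installed, and the shapes are purely illustrative. q, k and v are 5-D tensors of shape [batch, heads, height, width, depth]; the output keeps the query's spatial shape and the value depth.

import tensorflow as tf
from tensor2tensor.layers import common_attention

q = tf.random_normal([2, 4, 16, 16, 32])  # queries: [batch, heads, h, w, depth_k]
k = tf.random_normal([2, 4, 16, 16, 32])  # keys:    [batch, heads, h, w, depth_k]
v = tf.random_normal([2, 4, 16, 16, 64])  # values:  [batch, heads, h, w, depth_v]
output = common_attention.local_attention_2d(
    q, k, v, query_shape=(8, 8), memory_flange=(4, 4))
# output has shape [2, 4, 16, 16, 64]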

Example 1: testLocalUnmaskedAttention2D

# Required module: from tensor2tensor.layers import common_attention [as alias]
# Or: from tensor2tensor.layers.common_attention import local_attention_2d [as alias]
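# Note: in the original test file this case appears to be driven by a parameterized
# decorator that supplies batch, heads, length, depth_k, depth_v and query_shape;
# the decorator is omitted from this excerpt.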
def testLocalUnmaskedAttention2D(self, batch, heads, length,
                                   depth_k, depth_v, query_shape):
    if batch is None:
      batch = tf.random_uniform([], minval=0, maxval=5, dtype=tf.int32)
    q = tf.random_normal([batch, heads, length, length, depth_k])
    k = tf.random_normal([batch, heads, length, length, depth_k])
    v = tf.random_normal([batch, heads, length, length, depth_v])
    output = common_attention.local_attention_2d(
        q,
        k,
        v,
        query_shape=query_shape,
        memory_flange=(3, 3))
    if isinstance(batch, tf.Tensor):
      batch, res = self.evaluate([batch, output])
    else:
      res = self.evaluate(output)

    self.assertEqual(res.shape, (batch, heads, length, length, depth_v)) 
Developer: tensorflow, Project: tensor2tensor, Lines of code: 21, Source file: common_attention_test.py

Example 2: load_model

# Required module: from tensor2tensor.layers import common_attention [as alias]
# Or: from tensor2tensor.layers.common_attention import local_attention_2d [as alias]
def load_model(model_path):
    custom_layers = {
        "multihead_attention": multihead_attention,
        "Conv2D": L.Conv2D,
        "split_heads_2d": split_heads_2d,
        "local_attention_2d": local_attention_2d,
        "combine_heads_2d": combine_heads_2d
    }
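    # Rebuild the model architecture from YAML, registering the custom attention
    # ops so Keras can resolve them during deserialization.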
    model = model_from_yaml(open(os.path.join(model_path, "arch.yaml")).read(), custom_objects=custom_layers)

    full_path = os.path.join(model_path, "weights.h5")
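    # Heuristic check: weights saved from a multi_gpu_model wrapper expose the
    # inner model as a top-level group whose name contains "model".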
    with h5py.File(full_path, "r") as w:
        keys = list(w.keys())
        is_para = any(["model" in k for k in keys])

    if is_para:
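        # Re-create the 2-GPU wrapper so the saved (parallel) weight names match,
        # then unwrap the original single-GPU model from its layers.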
        para_model = multi_gpu_model(model, gpus=2)
        para_model.load_weights(full_path)
        model = para_model.layers[-2]
    else:
        model.load_weights(full_path)

    print("Model " + model_path + " loaded")
    return model 
Developer: BreezeWhite, Project: Music-Transcription-with-Semantic-Segmentation, Lines of code: 26, Source file: utils.py

Example 3: testLocalUnmaskedAttention2D

# Required module: from tensor2tensor.layers import common_attention [as alias]
# Or: from tensor2tensor.layers.common_attention import local_attention_2d [as alias]
def testLocalUnmaskedAttention2D(self):
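    # The 25x25 spatial dims are not a multiple of the (4, 4) query_shape, so
    # local_attention_2d pads internally and slices back; the output shape is unchanged.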
    x = np.random.rand(5, 4, 25, 25, 16)
    y = np.random.rand(5, 4, 25, 25, 16)
    with self.test_session() as session:
      a = common_attention.local_attention_2d(
          tf.constant(x, dtype=tf.float32),
          tf.constant(y, dtype=tf.float32),
          tf.constant(y, dtype=tf.float32),
          query_shape=(4, 4),
          memory_flange=(3, 3))
      session.run(tf.global_variables_initializer())
      res = session.run(a)
    self.assertEqual(res.shape, (5, 4, 25, 25, 16)) 
Developer: akzaidi, Project: fine-lm, Lines of code: 15, Source file: common_attention_test.py

Example 4: testLocalUnmaskedAttention2DMatchingBlockLength

# Required module: from tensor2tensor.layers import common_attention [as alias]
# Or: from tensor2tensor.layers.common_attention import local_attention_2d [as alias]
def testLocalUnmaskedAttention2DMatchingBlockLength(self):
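    # Here the (5, 5) query_shape evenly divides the 25x25 spatial dims
    # ("matching block length"), so no internal padding is needed.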
    x = np.random.rand(5, 4, 25, 25, 16)
    y = np.random.rand(5, 4, 25, 25, 16)
    with self.test_session() as session:
      a = common_attention.local_attention_2d(
          tf.constant(x, dtype=tf.float32),
          tf.constant(y, dtype=tf.float32),
          tf.constant(y, dtype=tf.float32),
          query_shape=(5, 5),
          memory_flange=(3, 3))
      session.run(tf.global_variables_initializer())
      res = session.run(a)
    self.assertEqual(res.shape, (5, 4, 25, 25, 16)) 
Developer: akzaidi, Project: fine-lm, Lines of code: 15, Source file: common_attention_test.py

Example 5: multihead_attention

# Required module: from tensor2tensor.layers import common_attention [as alias]
# Or: from tensor2tensor.layers.common_attention import local_attention_2d [as alias]
def multihead_attention(x, out_channel=64, d_model=32, n_heads=8, query_shape=(128, 24), memory_flange=(8, 8)):
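    # Project the input into query/key/value feature maps with 3x3 convolutions.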
    q = Conv2D(d_model, (3, 3), strides=(1, 1), padding="same", name="gen_q_conv")(x)
    k = Conv2D(d_model, (3, 3), strides=(1, 1), padding="same", name="gen_k_conv")(x)
    v = Conv2D(d_model, (3, 3), strides=(1, 1), padding="same", name="gen_v_conv")(x)

    q = split_heads_2d(q, n_heads)
    k = split_heads_2d(k, n_heads)
    v = split_heads_2d(v, n_heads)

    k_depth_per_head = d_model // n_heads
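    # Scale queries by 1/sqrt(depth per head), as in scaled dot-product attention.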
    q *= k_depth_per_head**-0.5

    
    """
    # local attention 2d
    v_shape = K.int_shape(v)
    q = pad_to_multiple(q, query_shape)
    k = pad_to_multiple(k, query_shape)
    v = pad_to_multiple(v, query_shape)

    paddings = ((0, 0), (memory_flange[0], memory_flange[1]), (memory_flange[0], memory_flange[1]))
    k = L.ZeroPadding3D(padding=paddings)(k)
    v = L.ZeroPadding3D(padding=paddings)(v)
    
    # Set up query blocks
    q_indices = gather_indices_2d(q, query_shape, query_shape)
    q_new = gather_blocks_2d(q, q_indices)

    # Set up key and value blocks
    memory_shape = (query_shape[0] + 2*memory_flange[0],
                    query_shape[1] + 2*memory_flange[1])
    k_and_v_indices = gather_indices_2d(k, memory_shape, query_shape)
    k_new = gather_blocks_2d(k, k_and_v_indices)
    v_new = gather_blocks_2d(v, k_and_v_indices)

    output = dot_attention(q_new, k_new, v_new)

    # Put output back into original shapes
    padded_shape = K.shape(q)
    output = scatter_blocks_2d(output, q_indices, padded_shape) 

    # Remove padding
    output = K.slice(output, [0, 0, 0, 0, 0], [-1, -1, v_shape[2], v_shape[3], -1])
    """

    output = local_attention_2d(q, k, v, query_shape=query_shape, memory_flange=memory_flange)
    
    output = combine_heads_2d(output)
    output = Conv2D(out_channel, (3, 3), strides=(1, 1), padding="same", use_bias=False)(output)
    
    return output 
Developer: BreezeWhite, Project: Music-Transcription-with-Semantic-Segmentation, Lines of code: 53, Source file: model_attn.py


Note: The tensor2tensor.layers.common_attention.local_attention_2d examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their respective authors, and copyright of the source code remains with the original authors. Please consult each project's License before using or redistributing the code; do not republish without permission.