当前位置: 首页>>代码示例>>Python>>正文


Python layer_utils.convert_dense_weights_data_format方法代码示例

本文整理汇总了Python中keras.utils.layer_utils.convert_dense_weights_data_format方法的典型用法代码示例。如果您正苦于以下问题:Python layer_utils.convert_dense_weights_data_format方法的具体用法?Python layer_utils.convert_dense_weights_data_format怎么用?Python layer_utils.convert_dense_weights_data_format使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在keras.utils.layer_utils的用法示例。


在下文中一共展示了layer_utils.convert_dense_weights_data_format方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_convert_weights

# 需要导入模块: from keras.utils import layer_utils [as 别名]
# 或者: from keras.utils.layer_utils import convert_dense_weights_data_format [as 别名]
def test_convert_weights():
    def get_model(shape, data_format):
        model = Sequential()
        model.add(Conv2D(filters=2,
                         kernel_size=(4, 3),
                         input_shape=shape,
                         data_format=data_format))
        model.add(Flatten())
        model.add(Dense(5))
        return model

    for data_format in ['channels_first', 'channels_last']:
        if data_format == 'channels_first':
            shape = (3, 5, 5)
            target_shape = (5, 5, 3)
            prev_shape = (2, 3, 2)
            flip = lambda x: np.flip(np.flip(x, axis=2), axis=3)
            transpose = lambda x: np.transpose(x, (0, 2, 3, 1))
            target_data_format = 'channels_last'
        elif data_format == 'channels_last':
            shape = (5, 5, 3)
            target_shape = (3, 5, 5)
            prev_shape = (2, 2, 3)
            flip = lambda x: np.flip(np.flip(x, axis=1), axis=2)
            transpose = lambda x: np.transpose(x, (0, 3, 1, 2))
            target_data_format = 'channels_first'

        model1 = get_model(shape, data_format)
        model2 = get_model(target_shape, target_data_format)
        conv = K.function([model1.input], [model1.layers[0].output])

        x = np.random.random((1,) + shape)

        # Test equivalence of convert_all_kernels_in_model
        convout1 = conv([x])[0]
        layer_utils.convert_all_kernels_in_model(model1)
        convout2 = flip(conv([flip(x)])[0])

        assert_allclose(convout1, convout2, atol=1e-5)

        # Test equivalence of convert_dense_weights_data_format
        out1 = model1.predict(x)
        layer_utils.convert_dense_weights_data_format(model1.layers[2], prev_shape, target_data_format)
        for (src, dst) in zip(model1.layers, model2.layers):
            dst.set_weights(src.get_weights())
        out2 = model2.predict(transpose(x))

        assert_allclose(out1, out2, atol=1e-5) 
开发者ID:hello-sea,项目名称:DeepLearning_Wavelet-LSTM,代码行数:50,代码来源:layer_utils_test.py


注:本文中的keras.utils.layer_utils.convert_dense_weights_data_format方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。