当前位置: 首页>>代码示例>>Python>>正文


Python applications.Xception方法代码示例

本文整理汇总了Python中keras.applications.Xception方法的典型用法代码示例。如果您正苦于以下问题:Python applications.Xception方法的具体用法?Python applications.Xception怎么用?Python applications.Xception使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块keras.applications的用法示例。


在下文中一共展示了applications.Xception方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_bare_keras_module

# 需要导入模块: from keras import applications [as 别名]
# 或者: from keras.applications import Xception [as 别名]
def test_bare_keras_module(self):
        """Check that a GraphFunction built from a Keras model reproduces that model's predictions.

        For InceptionV3, Xception and ResNet50 (ImageNet weights), every sample
        JPEG is loaded, preprocessed and stacked into one batch. The batch is
        run through the Keras model directly and through the GraphFunction
        imported into an IsolatedSession; both outputs must agree to
        ``self.featurizerCompareDigitsExact`` decimals.
        """
        jpeg_pattern = os.path.join(_getSampleJPEGDir(), '*.jpg')
        sample_paths = glob(jpeg_pattern)

        # (model constructor, its preprocess_input, target size handed to load_img)
        cases = [
            (InceptionV3, iv3.preprocess_input, model_sizes['InceptionV3']),
            (Xception, xcpt.preprocess_input, model_sizes['Xception']),
            (ResNet50, rsnt.preprocess_input, model_sizes['ResNet50']),
        ]

        for build_model, preprocess, size in cases:
            reference_model = build_model(weights="imagenet")

            # WARNING: the batch axis must be added before preprocessing,
            # or the ResNet50 preprocessor fails.
            batch = np.vstack([
                preprocess(np.expand_dims(img_to_array(load_img(path, target_size=size)), axis=0))
                for path in sample_paths
            ])

            expected = reference_model.predict(batch)
            graph_fn = GraphFunction.fromKeras(reference_model)

            with IsolatedSession(using_keras=True) as session:
                K.set_learning_phase(0)
                inputs, outputs = session.importGraphFunction(graph_fn)
                actual = session.run(outputs[0], {inputs[0]: batch})

            np.testing.assert_array_almost_equal(
                actual, expected, decimal=self.featurizerCompareDigitsExact)
开发者ID:databricks,项目名称:spark-deep-learning,代码行数:32,代码来源:test_pieces.py

示例2: test_pipeline

# 需要导入模块: from keras import applications [as 别名]
# 或者: from keras.applications import Xception [as 别名]
def test_pipeline(self):
        """Check that a two-stage GraphFunction pipeline matches plain Keras Xception predictions.

        The pipeline chains the image-struct converter ('spimage') with the
        Xception GraphFunction ('xception'). Each sample JPEG is preprocessed,
        predicted with the Keras model, then fed through the pipeline as an
        image-struct dict; both outputs must agree to
        ``self.featurizerCompareDigitsExact`` decimals.
        """
        sample_paths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg'))

        reference_model = Xception(weights="imagenet")
        converter = gfac.buildSpImageConverter('BGR', 'float32')
        pipeline = GraphFunction.fromList([
            ('spimage', converter),
            ('xception', GraphFunction.fromKeras(reference_model)),
        ])

        for path in sample_paths:
            image = load_img(path, target_size=model_sizes['Xception'])
            batch = xcpt.preprocess_input(np.expand_dims(img_to_array(image), axis=0))
            expected = reference_model.predict(batch)

            struct_fields = imageArrayToStruct(batch).asDict()
            struct_fields['data'] = bytes(struct_fields['data'])

            with IsolatedSession() as session:
                # Import with a blank scope name so that the struct field
                # names match the graph's input op names.
                inputs, outputs = session.importGraphFunction(pipeline, prefix="")
                feed = {
                    tensor: struct_fields[tfx.op_name(tensor, session.graph)]
                    for tensor in inputs
                }
                actual = session.run(outputs[0], feed_dict=feed)
                # To inspect the graph, uncomment:
                # tfx.write_visualization_html(session.graph,
                # NamedTemporaryFile(prefix="gdef", suffix=".html").name)

            np.testing.assert_array_almost_equal(
                actual, expected, decimal=self.featurizerCompareDigitsExact)
开发者ID:databricks,项目名称:spark-deep-learning,代码行数:33,代码来源:test_pieces.py

示例3: test_validate_keras_xception

# 需要导入模块: from keras import applications [as 别名]
# 或者: from keras.applications import Xception [as 别名]
def test_validate_keras_xception(self):
        """Export a Keras Xception model to PMML and assert the file validates against ``self.schema``.

        The model is built with ImageNet weights on a 224x224x3 input tensor
        and exported with 1000 stringified class labels ("0" .. "999").
        """
        input_tensor = Input(shape=(224, 224, 3))
        model = Xception(weights="imagenet", input_tensor=input_tensor)
        file_name = "keras"+model.name+".pmml"
        pmml_obj = KerasToPmml(model,dataSet="image",predictedClasses=[str(i) for i in range(1000)])
        # Close (and thereby flush) the output file before the validator
        # re-opens it by name; the original left the handle to be reclaimed
        # by the garbage collector.
        with open(file_name, 'w') as pmml_file:
            pmml_obj.export(pmml_file, 0)
        self.assertEqual(self.schema.is_valid(file_name), True)
开发者ID:nyoka-pmml,项目名称:nyoka,代码行数:9,代码来源:_validateSchema.py

示例4: test_xception

# 需要导入模块: from keras import applications [as 别名]
# 或者: from keras.applications import Xception [as 别名]
def test_xception():
    """Run the shared application checks against ``applications.Xception``."""
    model_ctor = applications.Xception
    feature_dim = 2048  # passed as ``last_dim`` to the checks below

    _test_application_basic(model_ctor)
    # The remaining checks all take (constructor, last_dim), in this order.
    for check in (_test_application_notop,
                  _test_application_variable_input_channels,
                  _test_app_pooling):
        check(model_ctor, feature_dim)
开发者ID:hello-sea,项目名称:DeepLearning_Wavelet-LSTM,代码行数:9,代码来源:applications_test.py

示例5: get_imagenet_architecture

# 需要导入模块: from keras import applications [as 别名]
# 或者: from keras.applications import Xception [as 别名]
def get_imagenet_architecture(architecture, variant, size, alpha, output_layer, include_top=False, weights='imagenet'):
    """Build a ``keras.applications`` model for the requested architecture.

    Args:
        architecture: One of 'densenet', 'inception-resnet-v2', 'mobilenet',
            'mobilenet-v2', 'nasnet', 'resnet-50', 'vgg-16', 'vgg-19',
            'xception' or 'inception-v3'.
        variant: Only read for 'densenet' ('densenet-121', 'densenet-169',
            'densenet-201'; 'auto' selects 'densenet-121') and for 'nasnet'
            ('large' or 'auto' selects NASNetLarge, any other value selects
            NASNetMobile).
        size: Edge length used for the ``(size, size, 3)`` input shape, or
            'auto' to let ``get_image_size`` choose it.
        alpha: Forwarded to MobileNet / MobileNetV2 only.
        output_layer: 'last' to return the model as built; otherwise an int
            index into ``model.layers`` or a layer name for
            ``model.get_layer``, whose output becomes the model's output.
        include_top: Forwarded to the Keras constructor. When true,
            ``output_layer`` must be 'last'.
        weights: Forwarded to the Keras constructor.

    Returns:
        The Keras model, truncated at ``output_layer`` if one was requested.

    Raises:
        VergeMLError: If the architecture or densenet variant is not
            supported, or if the requested output layer cannot be found.
    """
    from keras import applications, Model

    if include_top:
        assert output_layer == 'last'

    if size == 'auto':
        size = get_image_size(architecture, variant, size)

    shape = (size, size, 3)
    # Keyword arguments shared by every constructor below.
    common = dict(weights=weights, include_top=include_top, input_shape=shape)

    if architecture == 'densenet':
        if variant == 'auto':
            variant = 'densenet-121'
        if variant == 'densenet-121':
            model = applications.DenseNet121(**common)
        elif variant == 'densenet-169':
            model = applications.DenseNet169(**common)
        elif variant == 'densenet-201':
            model = applications.DenseNet201(**common)
        else:
            # Previously fell through with `model` unbound and failed later
            # with an unrelated-looking UnboundLocalError.
            raise VergeMLError('unknown densenet variant: {}'.format(variant))
    elif architecture == 'inception-resnet-v2':
        model = applications.InceptionResNetV2(**common)
    elif architecture == 'mobilenet':
        model = applications.MobileNet(alpha=alpha, **common)
    elif architecture == 'mobilenet-v2':
        model = applications.MobileNetV2(alpha=alpha, **common)
    elif architecture == 'nasnet':
        if variant == 'auto':
            variant = 'large'
        if variant == 'large':
            model = applications.NASNetLarge(**common)
        else:
            # Any non-'large' variant maps to the mobile network.
            model = applications.NASNetMobile(**common)
    elif architecture == 'resnet-50':
        model = applications.ResNet50(**common)
    elif architecture == 'vgg-16':
        model = applications.VGG16(**common)
    elif architecture == 'vgg-19':
        model = applications.VGG19(**common)
    elif architecture == 'xception':
        model = applications.Xception(**common)
    elif architecture == 'inception-v3':
        model = applications.InceptionV3(**common)
    else:
        # Fail with a clear message instead of leaving `model` unbound.
        raise VergeMLError('unknown architecture: {}'.format(architecture))

    if output_layer != 'last':
        try:
            if isinstance(output_layer, int):
                layer = model.layers[output_layer]
            else:
                layer = model.get_layer(output_layer)
        except Exception:
            raise VergeMLError('layer not found: {}'.format(output_layer))
        # Re-wrap so the chosen layer's output becomes the model output.
        model = Model(inputs=model.input, outputs=layer.output)

    return model
开发者ID:mme,项目名称:vergeml,代码行数:57,代码来源:features.py


注:本文中的keras.applications.Xception方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。