当前位置: 首页>>代码示例>>Python>>正文


Python backend.get_session方法代码示例

本文整理汇总了Python中tensorflow.keras.backend.get_session方法的典型用法代码示例。如果您正苦于以下问题：Python backend.get_session方法的具体用法？Python backend.get_session怎么用？Python backend.get_session使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.keras.backend的用法示例。


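在下文中一共展示了backend.get_session方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。注意：get_session仅存在于TensorFlow 1.x的Keras后端中；在TensorFlow 2.x中需改用tf.compat.v1.keras.backend.get_session。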

示例1: main

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def main(base_model_name, weights_file, export_path):
    """Export a trained NIMA model as a TensorFlow Serving SavedModel.

    Args:
        base_model_name: backbone name passed to ``Nima``.
        weights_file: path of the Keras weights loaded into ``nima_model``.
        export_path: target directory handed to ``SavedModelBuilder``.
    """
    # Build the architecture without pretrained weights, then load the trained ones.
    nima = Nima(base_model_name, weights=None)
    nima.build()
    keras_model = nima.nima_model
    keras_model.load_weights(weights_file)

    # Put Keras in inference mode for the exported graph.
    K.set_learning_phase(0)

    # MobileNet-style backbones need these custom objects to be resolvable.
    mobilenet_objects = {'relu6': relu6, 'DepthwiseConv2D': DepthwiseConv2D}
    with CustomObjectScope(mobilenet_objects):
        saved_model = saved_model_builder.SavedModelBuilder(export_path)
        serving_sig = predict_signature_def(
            inputs={'input_image': keras_model.input},
            outputs={'quality_prediction': keras_model.output},
        )
        saved_model.add_meta_graph_and_variables(
            sess=K.get_session(),
            tags=[tag_constants.SERVING],
            signature_def_map={'image_quality': serving_sig},
        )
        saved_model.save()

    print(f'TF model exported to: {export_path}')
开发者ID:idealo,项目名称:image-quality-assessment,代码行数:27,代码来源:save_tfs_model.py

示例2: secure_model

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def secure_model(model, **kwargs):
    """Secure a plaintext model from the current session.

    Freezes the current Keras session graph up to the op names of
    ``model.outputs``. The frozen graph is written as a binary ``model.pb``
    under ``_TMPDIR`` and reloaded with ``load_graph``. It is then converted
    with a tf-encrypted ``Converter``.

    Args:
        model: Keras model whose ``outputs`` select the subgraph to freeze.
        **kwargs: ``batch_size`` (default 1) is removed from ``kwargs`` and
            passed to ``load_graph``. Every other keyword argument is
            forwarded to ``tfe.convert.convert.Converter``.

    Returns:
        A ``PrivateModel`` wrapping the converted output.
    """
    session = K.get_session()
    min_graph = graph_util.convert_variables_to_constants(
        session, session.graph_def, [node.op.name for node in model.outputs]
    )
    graph_fname = "model.pb"
    tf.train.write_graph(min_graph, _TMPDIR, graph_fname, as_text=False)

    # Consume batch_size here so it is not forwarded to the Converter.
    batch_size = kwargs.pop("batch_size", 1)

    graph_def, inputs = load_graph(
        os.path.join(_TMPDIR, graph_fname), batch_size=batch_size
    )

    c = tfe.convert.convert.Converter(tfe.convert.registry(), **kwargs)
    y = c.convert(remove_training_nodes(graph_def), "input-provider", inputs)

    return PrivateModel(y)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:24,代码来源:private_model.py

示例3: __init__

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def __init__(self, model, shape):
    """Freeze ``model`` into a constant GraphDef fed by a new input placeholder.

    Args:
        model: Keras model. It is called on a fresh float32 placeholder, and
            its last two outputs become the frozen graph's outputs. It is
            assumed to return at least two tensors; TODO confirm with callers.
        shape: per-sample input shape. Its first three entries are used, and
            a ``None`` batch dimension is prepended.
    """
    shape = (None, shape[0], shape[1], shape[2])
    x_name = 'image_tensor_x'
    sess = K.get_session()
    # Do not use ``with sess:``. Session.__exit__ closes the session, and this
    # is Keras' shared global session, so ``model`` would become unusable.
    # as_default() makes it the default session without closing it on exit.
    with sess.graph.as_default(), sess.as_default():
        x_tensor = tf.placeholder(tf.float32, shape, x_name)
        K.set_learning_phase(0)
        y_tensor = model(x_tensor)
        # Freezing needs op names; .op.name drops the ':N' output-index suffix.
        y_name = [y_tensor[-1].op.name, y_tensor[-2].op.name]
        graph = sess.graph.as_graph_def()
        graph0 = tf.graph_util.convert_variables_to_constants(sess, graph, y_name)
        graph1 = tf.graph_util.remove_training_nodes(graph0)

    self.x_name = [x_name]
    self.y_name = y_name
    self.frozen = graph1
    self.model = model
开发者ID:csvance,项目名称:keras-mobile-detectnet,代码行数:18,代码来源:model.py

示例4: infer

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def infer(self, yield_single_examples=False):
  '''Run inference on the INFER split and log the greedy CTC decode per utterance.

  Args:
    yield_single_examples: not used by this method; kept for interface
      compatibility.
  '''
  mode = utils.INFER
  # Data must be initialized before the model is built.
  infer_ds, infer_task = self.input_data(mode=mode)
  infer_gen = tf.data.make_one_shot_iterator(infer_ds)

  self.model_fn(mode=mode)
  assert self._built

  # Load the model's inference function.
  infer_func = self.get_metric_func()

  # Build the get_next op and fetch the session once. Calling get_next()
  # inside the loop would add a new op to the graph on every iteration.
  next_features = infer_gen.get_next()[0]
  sess = tf.keras.backend.get_session()

  for _ in range(len(infer_task)):
    batch_data = sess.run(next_features)
    batch_input = batch_data['inputs']
    batch_uttid = batch_data['uttids'].tolist()
    batch_predict = infer_func(batch_input)[0]
    batch_decode = py_ctc.ctc_greedy_decode(batch_predict, 0, unique=True)
    for utt_index, uttid in enumerate(batch_uttid):
      logging.info("utt ID: %s", uttid)
      logging.info("infer result: %s", batch_decode[utt_index])
开发者ID:didi,项目名称:delta,代码行数:25,代码来源:asr_solver.py

示例5: main

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def main(*,
         phi=1,
         weighted_bifpn=False,
         model_path='checkpoints/2019-12-03/pascal_05_0.6283_1.1975_0.8029.h5',
         output_dir='./checkpoints/2019-12-03/',
         output_name='pascal_05.pb',
         score_threshold=0.5):
    """Freeze an EfficientDet prediction model into a binary GraphDef file.

    Builds the model with ``efficientdet`` and loads weights by layer name
    from ``model_path``. It then freezes the current Keras session up to the
    prediction model's output ops and writes the result to
    ``output_dir``/``output_name``. Every default reproduces the original
    hard-coded configuration, so ``main()`` behaves as before.

    Args:
        phi: passed through to ``efficientdet``.
        weighted_bifpn: passed through to ``efficientdet``.
        model_path: weights file loaded into the prediction model.
        output_dir: directory handed to ``tf.train.write_graph``.
        output_name: file name handed to ``tf.train.write_graph``.
        score_threshold: passed through to ``efficientdet``.
    """
    # Only the number of classes is used; the names document the label set.
    classes = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
        'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    ]
    num_classes = len(classes)
    # Only the prediction model is exported; the first returned model is unused.
    _, prediction_model = efficientdet(phi=phi,
                                       weighted_bifpn=weighted_bifpn,
                                       num_classes=num_classes,
                                       score_threshold=score_threshold)
    prediction_model.load_weights(model_path, by_name=True)

    session = K.get_session()
    output_names = [out.op.name for out in prediction_model.outputs]
    frozen_graph = freeze_session(session, output_names=output_names)
    tf.train.write_graph(frozen_graph, output_dir, output_name, as_text=False)
开发者ID:xuannianz,项目名称:EfficientDet,代码行数:22,代码来源:freeze_model.py

示例6: reset_weights

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def reset_weights(model, session=None):
    """
    Reset weights of model with the appropriate initializer.

    Note: only uses "kernel_initializer" and "bias_initializer".
    Does not close the session.

    Reference:
    https://www.codementor.io/nitinsurya/how-to-re-initialize-keras-model-weights-et41zre2g

    Parameters:
        model: keras model to reset
        session (optional): the session to run initializers in; defaults to
            K.get_session()
    """

    if session is None:
        session = K.get_session()

    for layer in model.layers:
        reset = False
        # A layer can declare an initializer yet hold no variable. For example,
        # Dense(use_bias=False) keeps `bias_initializer` but sets `bias` to
        # None, so check the variable itself before running its initializer.
        kernel = getattr(layer, 'kernel', None)
        if hasattr(layer, 'kernel_initializer') and kernel is not None:
            kernel.initializer.run(session=session)
            reset = True

        bias = getattr(layer, 'bias', None)
        if hasattr(layer, 'bias_initializer') and bias is not None:
            bias.initializer.run(session=session)
            reset = True

        if not reset:
            # print() does not apply %-formatting itself; format explicitly.
            print('Could not find initializer for layer %s. skipping' % layer.name)
开发者ID:adalca,项目名称:neuron,代码行数:31,代码来源:utils.py

示例7: tpu_compatible

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def tpu_compatible():
    '''Patch ``KerasTPUModel.compile`` to work around problems met with Keras TPU models.

    The patch is installed at most once per process; later calls return
    immediately. The patched ``compile`` raises ValueError on TensorFlow
    older than 1.13. After delegating to the original ``compile``, it
    initializes any global variables that are still uninitialized.
    '''
    if getattr(tpu_compatible, 'once', False):
        return
    import tensorflow as tf
    import tensorflow.keras.backend as K
    _version = tf.__version__.split('.')
    major, minor = int(_version[0]), int(_version[1])
    is_correct_version = major >= 2 or (major >= 1 and minor >= 13)
    # NOTE(review): tensorflow.contrib was removed in TF 2.x, so this import
    # fails there even though is_correct_version would be True.
    from tensorflow.contrib.tpu.python.tpu.keras_support import KerasTPUModel

    def initialize_uninitialized_variables():
        sess = K.get_session()
        uninitialized_variables = {
            name.decode('ascii')
            for name in sess.run(tf.report_uninitialized_variables())
        }
        init_op = tf.variables_initializer(
            [v for v in tf.global_variables() if v.name.split(':')[0] in uninitialized_variables]
        )
        sess.run(init_op)

    _tpu_compile = KerasTPUModel.compile

    def tpu_compile(self,
                    optimizer,
                    loss=None,
                    metrics=None,
                    loss_weights=None,
                    sample_weight_mode=None,
                    weighted_metrics=None,
                    target_tensors=None,
                    **kwargs):
        if not is_correct_version:
            raise ValueError('You need tensorflow >= 1.13 for better keras tpu support!')
        _tpu_compile(self, optimizer, loss, metrics, loss_weights,
                     sample_weight_mode, weighted_metrics,
                     target_tensors, **kwargs)
        initialize_uninitialized_variables()  # for unknown reason, we should run this after compile sometimes

    KerasTPUModel.compile = tpu_compile
    # Mark as done only once the patch is installed, so a failed import above
    # does not silently turn every later call into a no-op.
    tpu_compatible.once = True
开发者ID:Separius,项目名称:BERT-keras,代码行数:40,代码来源:__init__.py

示例8: export_split_edge_case

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_split_edge_case(filename, input_shape):
    """Build the split edge-case Keras model and pass its output to ``export``.

    The model comes from ``_keras_model_core(split_edge_case_builder, ...)``.
    Its ``output`` tensor, ``filename`` and the current Keras session are
    handed to ``export``, whose result is returned.
    """
    keras_model, _unused = _keras_model_core(split_edge_case_builder, shape=input_shape)
    current_session = K.get_session()
    model_output = keras_model.output
    return export(model_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:8,代码来源:convert_test.py

示例9: export_flatten

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_flatten(filename, input_shape):
    """Build a one-layer Flatten model and pass its output to ``export``.

    ``input_shape`` includes a leading batch dimension. The layer is declared
    with ``input_shape[1:]``, and one random batch of the full shape is run
    through ``predict`` before the export.
    """
    flatten_model = Sequential()
    flatten_model.add(Flatten(input_shape=input_shape[1:]))
    random_batch = np.random.uniform(size=input_shape)
    flatten_model.predict(random_batch)

    current_session = K.get_session()
    flatten_output = flatten_model.get_layer("flatten").output
    return export(flatten_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:11,代码来源:convert_test.py

示例10: export_keras_multilayer

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_keras_multilayer(filename, input_shape):
    """Build the multilayer Keras model and pass its output to ``export``.

    The model comes from ``_keras_model_core(keras_multilayer_builder, ...)``.
    ``export`` receives its ``output`` tensor, ``filename`` and the current
    Keras session, and its result is returned.
    """
    keras_model, _unused = _keras_model_core(keras_multilayer_builder, shape=input_shape)
    current_session = K.get_session()
    model_output = keras_model.output
    return export(model_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:8,代码来源:convert_test.py

示例11: export_keras_conv2d

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_keras_conv2d(filename, input_shape):
    """Build the Conv2D test model and pass its "conv2d" layer output to ``export``."""
    conv_model, _unused = _keras_conv2d_core(shape=input_shape)
    current_session = K.get_session()
    conv_output = conv_model.get_layer("conv2d").output
    return export(conv_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:8,代码来源:convert_test.py

示例12: export_keras_dense

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_keras_dense(filename, input_shape):
    """Build the Dense test model and pass its "dense" layer output to ``export``."""
    dense_model, _unused = _keras_dense_core(shape=input_shape)
    current_session = K.get_session()
    dense_output = dense_model.get_layer("dense").output
    return export(dense_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:8,代码来源:convert_test.py

示例13: export_keras_batchnorm

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_keras_batchnorm(filename, input_shape):
    """Build the BatchNormalization test model and pass its "batch_normalization" layer output to ``export``."""
    bn_model, _unused = _keras_batchnorm_core(shape=input_shape)
    current_session = K.get_session()
    bn_output = bn_model.get_layer("batch_normalization").output
    return export(bn_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:8,代码来源:convert_test.py

示例14: export_keras_global_avgpool

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_keras_global_avgpool(filename, input_shape):
    """Build the global-average-pooling test model and pass its "global_average_pooling2d" output to ``export``."""
    pool_model, _unused = _keras_global_avgpool_core(shape=input_shape)
    current_session = K.get_session()
    pool_output = pool_model.get_layer("global_average_pooling2d").output
    return export(pool_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:8,代码来源:convert_test.py

示例15: export_keras_global_maxpool

# 需要导入模块: from tensorflow.keras import backend [as 别名]
# 或者: from tensorflow.keras.backend import get_session [as 别名]
def export_keras_global_maxpool(filename, input_shape):
    """Build the global-max-pooling test model and pass its "global_max_pooling2d" output to ``export``."""
    pool_model, _unused = _keras_global_maxpool_core(shape=input_shape)
    current_session = K.get_session()
    pool_output = pool_model.get_layer("global_max_pooling2d").output
    return export(pool_output, filename, sess=current_session)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:8,代码来源:convert_test.py


注:本文中的tensorflow.keras.backend.get_session方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。