当前位置: 首页>>代码示例>>Python>>正文


Python freeze_graph.freeze_graph_with_def_protos方法代码示例

本文整理汇总了Python中tensorflow.python.tools.freeze_graph.freeze_graph_with_def_protos方法的典型用法代码示例。如果您正苦于以下问题：Python freeze_graph.freeze_graph_with_def_protos方法的具体用法？Python freeze_graph.freeze_graph_with_def_protos怎么用？Python freeze_graph.freeze_graph_with_def_protos使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.python.tools.freeze_graph的用法示例。


在下文中一共展示了freeze_graph.freeze_graph_with_def_protos方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_checkpoint_v1

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos [as 别名]
def load_checkpoint_v1(self):
        """Restore the newest checkpoint next to ``self._tf_file`` and freeze it.

        The directory containing ``self._tf_file`` is searched for the latest
        checkpoint. Its ``.meta`` graph is imported, the weights are restored
        in a session, and the resulting GraphDef is passed to
        ``freeze_graph.freeze_graph_with_def_protos`` together with the
        output node names in ``self._outputs``.

        Returns:
            The GraphDef returned by ``freeze_graph_with_def_protos`` (no file
            is written because ``output_graph`` is ``None``).

        Raises:
            ValueError: if no checkpoint is found in that directory.
        """
        ckpt_dir = os.path.dirname(self._tf_file)
        latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
        if latest_ckpt is None:
            # latest_checkpoint() signals "nothing found" with None; fail with a
            # clear message instead of a TypeError on the concatenation below.
            raise ValueError(
                "No TensorFlow checkpoint found in {!r}".format(ckpt_dir))

        # Import the graph saved alongside the checkpoint into the default graph.
        saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
        with tf.Session() as session:
            # Initialise first (presumably so variables the checkpoint does not
            # cover still get a value), then overwrite with the saved weights.
            session.run(
                [
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ]
            )
            saver.restore(session, latest_ckpt)
            graph_def = session.graph.as_graph_def(add_shapes=True)

        # freeze_graph_with_def_protos splits a comma-separated string of node
        # names; also accept a list/tuple of names.
        output_node_names = self._outputs
        if not isinstance(output_node_names, str):
            output_node_names = ",".join(output_node_names)

        frozen_graph = freeze_graph.freeze_graph_with_def_protos(
            input_graph_def=graph_def,
            input_saver_def=None,
            input_checkpoint=latest_ckpt,
            output_node_names=output_node_names,
            restore_op_name="",
            filename_tensor_name="",
            output_graph=None,
            clear_devices=True,
            initializer_nodes=""
        )
        return frozen_graph
开发者ID:sony,项目名称:nnabla,代码行数:27,代码来源:importer.py

示例2: main

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos [as 别名]
def main(unused_argv):
  """Export a frozen DeepLab inference graph to ``FLAGS.export_path``.

  Builds the inference graph (single-scale when ``FLAGS.inference_scales`` is
  exactly ``(1.0,)``, multi-scale otherwise), crops the prediction to the valid
  region, resizes it back to the input image size with nearest-neighbour
  sampling, and names the result ``_OUTPUT_NAME``. The model variables are then
  restored from ``FLAGS.checkpoint_path`` and frozen by
  ``freeze_graph_with_def_protos``, which writes the graph to
  ``FLAGS.export_path``.

  Args:
    unused_argv: Positional arguments from the app runner; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Prepare to export model to: %s', FLAGS.export_path)

  with tf.Graph().as_default():
    input_image, original_size, valid_size = _create_input_tensors()

    options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
        crop_size=FLAGS.crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    is_single_scale = tuple(FLAGS.inference_scales) == (1.0,)
    if is_single_scale:
      tf.logging.info('Exported model performs single-scale inference.')
      outputs = model.predict_labels(
          input_image,
          model_options=options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Exported model performs multi-scale inference.')
      outputs = model.predict_labels_multi_scale(
          input_image,
          model_options=options,
          eval_scales=FLAGS.inference_scales,
          add_flipped_images=FLAGS.add_flipped_images)

    # Keep only the valid region: batch 1, valid height x valid width.
    cropped = tf.slice(
        outputs[common.OUTPUT_TYPE],
        [0, 0, 0],
        [1, valid_size[0], valid_size[1]])

    # resize_images needs a trailing channel axis: go to [1, H, W, 1], resize
    # with nearest-neighbour (label ids are copied, never blended), then drop
    # the channel axis again.
    with_channel = tf.expand_dims(cropped, 3)
    resized = tf.image.resize_images(
        with_channel,
        original_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        align_corners=True)
    tf.identity(tf.squeeze(resized, 3), name=_OUTPUT_NAME)

    saver = tf.train.Saver(tf.model_variables())

    export_dir = os.path.dirname(FLAGS.export_path)
    tf.gfile.MakeDirs(export_dir)
    graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
    freeze_graph.freeze_graph_with_def_protos(
        input_graph_def=graph_def,
        input_saver_def=saver.as_saver_def(),
        input_checkpoint=FLAGS.checkpoint_path,
        output_node_names=_OUTPUT_NAME,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=FLAGS.export_path,
        clear_devices=True,
        initializer_nodes=None)
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:60,代码来源:export_model.py

示例3: dump_parameters

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos [as 别名]
def dump_parameters(self):
        """Freeze the latest training checkpoint into a base64-encoded GraphDef.

        Rebuilds the inference graph with ``batch_size=-1, n_steps=-1`` and
        restores every global variable except the ``previous_state_*`` ones
        from the newest checkpoint in ``FLAGS.checkpoint_dir``. The result is
        frozen with ``freeze_graph_with_def_protos`` and the serialized
        GraphDef is returned as a base64 string (intended for storage in the
        DB).

        Returns:
            ``{'pb_model_base64': <str>}`` on success, or ``None`` if freezing
            raised a ``RuntimeError`` (the error is logged).

        Raises:
            ValueError: if ``FLAGS.checkpoint_dir`` holds no checkpoint.
        """
        tf.reset_default_graph()
        _, outputs, _ = self.create_inference_graph(batch_size=-1, n_steps=-1)

        # Output nodes are referenced by op name: a Tensor via the op that
        # produces it, an Operation directly.
        output_names = ','.join(
            [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
            + [op.name for op in outputs.values() if isinstance(op, Operation)])

        # previous_state_* variables are neither restored nor frozen (see the
        # blacklist below) -- presumably because they are inference-time state
        # absent from the training checkpoint; confirm.
        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
        saver = tf.train.Saver(mapping)

        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if checkpoint is None or not checkpoint.model_checkpoint_path:
            raise ValueError(
                'No checkpoint found in {}'.format(FLAGS.checkpoint_dir))
        checkpoint_path = checkpoint.model_checkpoint_path

        try:
            frozen_graph = freeze_graph.freeze_graph_with_def_protos(
                input_graph_def=tf.get_default_graph().as_graph_def(),
                input_saver_def=saver.as_saver_def(),
                input_checkpoint=checkpoint_path,
                output_node_names=output_names,
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=None,
                clear_devices=False,
                variable_names_blacklist='previous_state_c,previous_state_h',
                initializer_nodes='')
            frozen_graph.version = 1
            # Serialize in memory; a temporary .pb file would contain exactly
            # these bytes, so no disk round trip is needed.
            pb_model_bytes = frozen_graph.SerializeToString()
        except RuntimeError as e:
            # Best-effort: log and signal failure with None, as before.
            logger.log('Error occured! {}'.format(e))
            return None

        return {'pb_model_base64': base64.b64encode(pb_model_bytes).decode('utf-8')}
开发者ID:nginyc,项目名称:rafiki,代码行数:59,代码来源:TfDeepSpeech.py

示例4: main

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos [as 别名]
def main(unused_argv):
  """Export a frozen DeepLab inference graph with int32 labels.

  Builds the inference graph (single-scale when ``FLAGS.inference_scales`` is
  exactly ``(1.0,)``, multi-scale otherwise). The prediction is cast to
  float32, cropped to the valid region, and resized back to the input image
  size with nearest-neighbour sampling. It is then cast to int32 and named
  ``_OUTPUT_NAME``. The model variables are restored from
  ``FLAGS.checkpoint_path`` and frozen by ``freeze_graph_with_def_protos``,
  which writes the graph to ``FLAGS.export_path``.

  Args:
    unused_argv: Positional arguments from the app runner; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Prepare to export model to: %s', FLAGS.export_path)

  with tf.Graph().as_default():
    img, img_size, valid_size = _create_input_tensors()

    options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
        crop_size=FLAGS.crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    # Choose the predictor and the keyword arguments specific to it.
    if tuple(FLAGS.inference_scales) == (1.0,):
      tf.logging.info('Exported model performs single-scale inference.')
      predict_fn = model.predict_labels
      scale_kwargs = dict(image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Exported model performs multi-scale inference.')
      predict_fn = model.predict_labels_multi_scale
      scale_kwargs = dict(eval_scales=FLAGS.inference_scales,
                          add_flipped_images=FLAGS.add_flipped_images)
    label_maps = predict_fn(img, model_options=options, **scale_kwargs)

    # Crop/resize in float32; the result is cast back to int32 below.
    as_float = tf.cast(label_maps[common.OUTPUT_TYPE], tf.float32)
    valid_region = tf.slice(
        as_float,
        [0, 0, 0],
        [1, valid_size[0], valid_size[1]])

    # resize_images needs a trailing channel axis: [1, H, W] -> [1, H, W, 1].
    resized = tf.image.resize_images(
        tf.expand_dims(valid_region, 3),
        img_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        align_corners=True)
    int_labels = tf.cast(tf.squeeze(resized, 3), tf.int32)
    tf.identity(int_labels, name=_OUTPUT_NAME)

    saver = tf.train.Saver(tf.model_variables())

    tf.gfile.MakeDirs(os.path.dirname(FLAGS.export_path))
    graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
    freeze_graph.freeze_graph_with_def_protos(
        input_graph_def=graph_def,
        input_saver_def=saver.as_saver_def(),
        input_checkpoint=FLAGS.checkpoint_path,
        output_node_names=_OUTPUT_NAME,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=FLAGS.export_path,
        clear_devices=True,
        initializer_nodes=None)
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:61,代码来源:export_model.py


注:本文中的tensorflow.python.tools.freeze_graph.freeze_graph_with_def_protos方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。