当前位置: 首页>>代码示例>>Python>>正文


Python save.save函数代码示例

本文整理汇总了Python中tensorflow.python.saved_model.save.save函数的典型用法代码示例。如果您正苦于以下问题：Python save函数的具体用法？Python save怎么用？Python save使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了save函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_signature_attribute_reserved

 def test_signature_attribute_reserved(self):
   """Saving is refused while the root object has a `signatures` attribute.

   The first `save.save` call must raise a ValueError whose message matches
   "del obj.signatures"; once that attribute is deleted, the same object
   saves without error.
   """
   checkpoint = util.Checkpoint(signatures=variables.Variable(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   # NOTE(review): assertRaisesRegexp is a deprecated alias removed from
   # unittest in Python 3.12; prefer assertRaisesRegex where supported.
   with self.assertRaisesRegexp(ValueError, "del obj.signatures"):
     save.save(checkpoint, export_dir)
   del checkpoint.signatures
   save.save(checkpoint, export_dir)
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:save_test.py

示例2: test_function_with_captured_dataset

  def test_function_with_captured_dataset(self):
    """A tf.function that loops over a captured dataset can be exported.

    The module's dataset holds the squares of 0..4 and `__call__` sums
    `x * element` over it, so the reloaded signature called with x=3 must
    return 3 * (0 + 1 + 4 + 9 + 16). Skipped when a GPU is available.
    """
    if test_util.is_gpu_available():
      self.skipTest("Currently broken when a GPU is available.")

    class HasDataset(module.Module):

      def __init__(self):
        super(HasDataset, self).__init__()
        squares = dataset_ops.Dataset.range(5).map(lambda x: x ** 2)
        self.dataset = squares

      @def_function.function
      def __call__(self, x):
        # int64 accumulator, matching the int64 TensorSpec used for `x` below.
        total = array_ops.zeros([], dtype=dtypes.int64)
        for value in self.dataset:
          total += x * value
        return total

    root = HasDataset()
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    signature = root.__call__.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.int64))
    save.save(root, export_dir, signatures=signature)
    expected = {"output_0": 3 * (1 + 4 + 9 + 16)}
    self.assertAllClose(expected, _import_and_infer(export_dir, {"x": 3}))
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:save_test.py

示例3: testConstSavedModel

  def testConstSavedModel(self):
    """Functions in a loaded, variable-free signature are inlined on freezing.

    The exported `f` doubles its input and owns no variables. The loaded
    signature's graph has a function library but no variables. After
    `convert_variables_to_constants_v2`, the library is empty, and the frozen
    graph still produces the same value as the original `f`.
    """
    sample = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    concrete = root.f.get_concrete_function(sample)

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, export_dir, concrete)
    loaded = load(export_dir)
    signature = loaded.signatures["serving_default"]

    before = signature.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(before))
    self.assertTrue(before.library.function)

    frozen = convert_to_constants.convert_variables_to_constants_v2(
        signature)
    after = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(after))
    self.assertFalse(after.library.function)

    # The frozen graph must reproduce the original function's result.
    want = root.f(sample)
    got = self._evaluateGraphDef(after, signature, [sample.numpy()])
    self.assertEqual(want.numpy(), got)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:convert_to_constants_test.py

示例4: test_table

 def test_table(self):
   """A vocabulary-file HashTable keeps working after export and relocation.

   The vocabulary file is deleted after saving, and the export directory is
   renamed before the second reload. Lookups after reload therefore cannot
   read the original vocabulary file.
   """
   init = lookup_ops.TextFileInitializer(
       self._vocab_path,
       key_dtype=dtypes.string,
       key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
       value_dtype=dtypes.int64,
       value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
   table = lookup_ops.HashTable(init, default_value=-1)
   root = util.Checkpoint(table=table)
   root.table_user = def_function.function(
       root.table.lookup,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
   gamma_id = self.evaluate(root.table_user(constant_op.constant("gamma")))
   self.assertEqual(2, gamma_id)

   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, export_dir)
   file_io.delete_file(self._vocab_path)
   self.assertAllClose(
       {"output_0": [2, 0]},
       _import_and_infer(export_dir, {"keys": ["gamma", "alpha"]}))

   moved_dir = os.path.join(self.get_temp_dir(), "second_dir")
   # Asset paths should follow wherever the SavedModel is loaded from.
   file_io.rename(export_dir, moved_dir)
   self.assertAllClose(
       {"output_0": [2, 1]},
       _import_and_infer(moved_dir, {"keys": ["gamma", "beta"]}))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:27,代码来源:save_test.py

示例5: test_asset_loading

  def test_asset_loading(self):
    """Assets of a TF1 SavedModel survive two rounds of TF2 re-export.

    The TF1 model is loaded and checked. It is then re-saved with
    `save.save`, the previous directory is removed, and the new copy is
    reloaded and checked. The re-save/remove/reload step runs twice.
    """

    def load_and_check(path):
      # Load `path`, run table initializers, and verify two lookups.
      loaded = load.load(path)
      self.evaluate(lookup_ops.tables_initializer())
      serving = loaded.signatures["serving_default"]
      self.assertAllClose(
          {"output": [2, 0]},
          serving(start=constant_op.constant(["gamma", "alpha"])))
      return loaded

    old_path = self._v1_asset_saved_model()
    # Keep every loaded generation referenced until the test ends.
    generations = [load_and_check(old_path)]
    for _ in range(2):
      previous = generations[-1]
      new_path = os.path.join(self.get_temp_dir(), "saved_model",
                              str(ops.uid()))
      save.save(previous, new_path, signatures=previous.signatures)
      # Remove the previous export so the reload cannot read its files.
      shutil.rmtree(old_path)
      # Empty the TABLE_INITIALIZERS collection before the next load,
      # presumably so tables_initializer() only covers the new tables.
      del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
      generations.append(load_and_check(new_path))
      old_path = new_path
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:load_v1_in_v2_test.py

示例6: testVariableSavedModel

  def testVariableSavedModel(self):
    """Variables read by a loaded signature are frozen into constants.

    `f` multiplies its input by two variables (3. and 2.). The loaded
    signature's graph contains a StatefulPartitionedCall op. After
    `convert_variables_to_constants_v2`, the graph has no variables and no
    such op, and evaluating it still gives the same result as `f`.
    """
    sample = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    concrete = root.f.get_concrete_function(sample)

    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, export_dir, concrete)
    loaded = load(export_dir)
    signature = loaded.signatures["serving_default"]

    before = signature.graph.as_graph_def()
    self.assertTrue(self._hasStatefulPartitionedCallOp(before))

    frozen = convert_to_constants.convert_variables_to_constants_v2(
        signature)
    after = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(after))
    self.assertFalse(self._hasStatefulPartitionedCallOp(after))

    # The frozen graph must reproduce the original function's result.
    want = root.f(sample)
    got = self._evaluateGraphDef(after, signature, [sample.numpy()])
    self.assertEqual(want.numpy(), got)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:28,代码来源:convert_to_constants_test.py

示例7: test_non_concrete_error

 def test_non_concrete_error(self):
   """A polymorphic tf.function passed as `signatures` is rejected."""
   root = tracking.AutoCheckpointable()
   root.f = def_function.function(lambda x: 2. * x)
   one = constant_op.constant(1.)
   root.f(one)
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   # NOTE(review): assertRaisesRegexp is a deprecated alias removed from
   # unittest in Python 3.12; prefer assertRaisesRegex where supported.
   expected = "must be converted to concrete functions"
   with self.assertRaisesRegexp(ValueError, expected):
     save.save(root, export_dir, signatures=root.f)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:save_test.py

示例8: test_non_concrete_error

 def test_non_concrete_error(self):
   """A polymorphic tf.function passed as `signatures` is rejected."""
   root = tracking.AutoTrackable()
   root.f = def_function.function(lambda x: 2. * x)
   one = constant_op.constant(1.)
   root.f(one)
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   # NOTE(review): assertRaisesRegexp is a deprecated alias removed from
   # unittest in Python 3.12; prefer assertRaisesRegex where supported.
   expected = "Expected a TensorFlow function"
   with self.assertRaisesRegexp(ValueError, expected):
     save.save(root, export_dir, signatures=root.f)
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:save_test.py

示例9: test_single_function_default_signature

 def test_single_function_default_signature(self):
   """With no `signatures` argument, a lone zero-argument function is served.

   After reload, inference with no inputs must yield {"output_0": 3.}.
   """
   model = tracking.AutoCheckpointable()
   model.f = def_function.function(lambda: 3., input_signature=())
   model.f()
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(model, export_dir)
   expected = {"output_0": 3.}
   self.assertAllClose(expected, _import_and_infer(export_dir, {}))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:save_test.py

示例10: test_export_functional_keras_model

 def test_export_functional_keras_model(self):
   """A functional Keras model reloads with outputs matching the live model.

   The reloaded signature is fed through input "x", and its "out" output is
   compared with the in-memory model's result for an all-ones input.
   """
   inputs = input_layer.Input((4,), name="x")
   outputs = core.Dense(4, name="out")(inputs)
   model = training.Model(inputs, outputs)
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(model, export_dir)
   expected = {"out": model(array_ops.ones([1, 4]))}
   self.assertAllClose(
       expected, _import_and_infer(export_dir, {"x": [[1., 1., 1., 1.]]}))
开发者ID:aeverall,项目名称:tensorflow,代码行数:9,代码来源:save_test.py

示例11: test_nested_outputs

 def test_nested_outputs(self):
   """A concrete signature whose outputs contain a nested tuple is rejected."""
   root = tracking.AutoCheckpointable()
   root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
   root.f(constant_op.constant(1.))
   concrete = root.f.get_concrete_function(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   # NOTE(review): assertRaisesRegexp is a deprecated alias removed from
   # unittest in Python 3.12; prefer assertRaisesRegex where supported.
   with self.assertRaisesRegexp(ValueError, "non-flat outputs"):
     save.save(root, export_dir, signatures=concrete)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:9,代码来源:save_test.py

示例12: test_ambiguous_signatures

 def test_ambiguous_signatures(self):
   """Two candidate functions and no explicit `signatures` is an error.

   After `call` is traced and `second_function` is attached, `save.save`
   must raise a ValueError whose message mentions both names.
   """
   model = _ModelWithOptimizer()
   x_in = constant_op.constant([[3., 4.]])
   y_in = constant_op.constant([2.])
   model.call(x_in, y_in)
   model.second_function = def_function.function(lambda: 1.)
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   # NOTE(review): assertRaisesRegexp is a deprecated alias removed from
   # unittest in Python 3.12; prefer assertRaisesRegex where supported.
   with self.assertRaisesRegexp(ValueError, "call.*second_function"):
     save.save(model, export_dir)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:9,代码来源:save_test.py

示例13: test_nested_dict_outputs

 def test_nested_dict_outputs(self):
   """A signature returning a dict with a tuple-valued entry is rejected."""
   nested = def_function.function(
       lambda x: {"a": 2. * x, "b": (3. * x, 4. * x)})
   root = util.Checkpoint(f=nested)
   root.f(constant_op.constant(1.))
   concrete = root.f.get_concrete_function(constant_op.constant(1.))
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   # NOTE(review): assertRaisesRegexp is a deprecated alias removed from
   # unittest in Python 3.12; prefer assertRaisesRegex where supported.
   expected = "dictionary containing non-Tensor value"
   with self.assertRaisesRegexp(ValueError, expected):
     save.save(root, export_dir, signatures=concrete)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:save_test.py

示例14: test_no_reference_cycles

 def test_no_reference_cycles(self):
   """Export `self._model` with its traced `call` method as the signature.

   This body only performs the export. NOTE(review): the reference-cycle
   check implied by the name is not visible here; presumably it is applied
   outside this body (e.g. by a decorator). Confirm against the full file.
   The test is skipped on Python 2, after the model has been called once.
   """
   x_in = constant_op.constant([[3., 4.]])
   y_in = constant_op.constant([2.])
   self._model.call(x_in, y_in)
   if sys.version_info.major < 3:
     # TODO(allenl): debug reference cycles in Python 2.x
     self.skipTest("This test only works in Python 3+. Reference cycles are "
                   "created in older Python versions.")
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(self._model, export_dir, signatures=self._model.call)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:save_test.py

示例15: test_single_method_default_signature

 def test_single_method_default_signature(self):
   """With no `signatures` argument, the traced `call` method is served.

   After reload, feeding "x" and "y" must produce an output dict that
   contains a "loss" entry.
   """
   model = _ModelWithOptimizer()
   x_in = constant_op.constant([[3., 4.]])
   y_in = constant_op.constant([2.])
   model.call(x_in, y_in)
   export_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(model, export_dir)
   feeds = {"x": [[3., 4.]], "y": [2.]}
   outputs = _import_and_infer(export_dir, feeds)
   self.assertIn("loss", outputs)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:save_test.py


注：本文中的tensorflow.python.saved_model.save.save函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。