

Java SavedModelBundle.load Method Code Examples

This article collects typical usage examples of the org.tensorflow.SavedModelBundle.load method in Java. If you want to see how SavedModelBundle.load is called in real projects, the examples selected here may help. You can also explore further usage examples of the class the method belongs to, org.tensorflow.SavedModelBundle.


Five code examples of SavedModelBundle.load are shown below, sorted by popularity by default.

Example 1: loadModel

import org.tensorflow.SavedModelBundle; // package/class the method depends on
@Override
public SavedModelBundle loadModel(final Location source,
	final String modelName, final String... tags) throws IOException
{
	final String key = modelName + "/" + Arrays.toString(tags);

	// If the model is already cached in memory, return it.
	if (models.containsKey(key)) return models.get(key);

	// Get a local directory with unpacked model data.
	final File modelDir = modelDir(source, modelName);

	// Load the saved model.
	final SavedModelBundle model = //
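		SavedModelBundle.load(modelDir.getAbsolutePath(), tags);

	// Cache the loaded model so later calls with the same name and tags reuse it.
	models.put(key, model);

	return model;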
}
 
Developer ID: imagej, Project: imagej-tensorflow, Lines of code: 19, Source: DefaultTensorFlowService.java
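Example 1 keys its in-memory cache on the model name plus the tag set. If loadModel can be called from several threads, a separate contains-then-load step can let two callers miss the cache and load the same model twice. The sketch below shows one way to make lookup and load a single atomic step with ConcurrentHashMap.computeIfAbsent. It is generic over the cached type so it runs without TensorFlow on the classpath; ModelCache and its loader function are hypothetical names, and in real use T would be SavedModelBundle and the loader would call SavedModelBundle.load(dir, tags).

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;

/** Hypothetical cache keyed by model name + tags, mirroring the key scheme in Example 1. */
final class ModelCache<T> {
    private final Map<String, T> models = new ConcurrentHashMap<>();
    // Invoked at most once per key; in real use this would wrap SavedModelBundle.load(...).
    private final BiFunction<String, String[], T> loader;

    ModelCache(BiFunction<String, String[], T> loader) {
        this.loader = loader;
    }

    T get(String modelName, String... tags) {
        // Same key format as Example 1, e.g. "mnist/[serve]".
        String key = modelName + "/" + Arrays.toString(tags);
        // computeIfAbsent loads and stores atomically, so concurrent callers share one instance.
        return models.computeIfAbsent(key, k -> loader.apply(modelName, tags));
    }

    int size() {
        return models.size();
    }

    public static void main(String[] args) {
        int[] loads = {0};
        ModelCache<String> cache = new ModelCache<>((name, tags) -> {
            loads[0]++; // count how often the (stand-in) loader actually runs
            return "model:" + name + Arrays.toString(tags);
        });
        String a = cache.get("mnist", "serve");
        String b = cache.get("mnist", "serve");
        String c = cache.get("mnist", "train");
        // prints: model:mnist[serve] true model:mnist[train] loads=2
        System.out.println(a + " " + (a == b) + " " + c + " loads=" + loads[0]);
    }
}
```

Because different tag sets produce different keys, loading the same directory with "serve" and "train" yields two cache entries, just as in Example 1.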

Example 2: testLoadModel

import org.tensorflow.SavedModelBundle; // package/class the method depends on
public void testLoadModel() throws Exception {
    String modelDir = "examples/tensorflow/estimator/model";

    // Load the model from the last subdirectory (Estimator exports go into timestamped subdirectories).
    // All three resources are closed in reverse order; closing the bundle also closes its Session and Graph.
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir + "/" + SpongeUtils.getLastSubdirectory(modelDir), "serve");
            Tensor x = Tensor.create(new float[] { 2, 5, 8, 1 });
            Tensor y = bundle.session().runner().feed("x", x).fetch("y").run().get(0)) {
        logger.info("y = {}", y.floatValue());
    }
}
 
Developer ID: softelnet, Project: sponge, Lines of code: 12, Source: TensorflowTest.java
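Example 2 relies on SpongeUtils.getLastSubdirectory(modelDir), whose implementation is not shown here. TensorFlow's Estimator.export_savedmodel writes each export into a new subdirectory named with a Unix timestamp, so a plausible equivalent picks the numerically largest subdirectory name. The class below is a hypothetical, stdlib-only sketch under that assumption, not the Sponge helper itself.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.Optional;
import java.util.stream.Stream;

/** Hypothetical helper: finds the newest timestamp-named export directory under a base dir. */
final class LatestExport {
    private LatestExport() {}

    static Optional<Path> latestExport(Path baseDir) throws IOException {
        try (Stream<Path> children = Files.list(baseDir)) {
            return children
                    .filter(Files::isDirectory)
                    // Keep only numeric names such as "1528145987" (at most 18 digits, so they fit in a long).
                    .filter(p -> p.getFileName().toString().matches("\\d{1,18}"))
                    // Compare numerically so "10" ranks above "9".
                    .max(Comparator.comparingLong(p -> Long.parseLong(p.getFileName().toString())));
        }
    }

    public static void main(String[] args) throws IOException {
        Path base = Paths.get(args.length > 0 ? args[0] : "examples/tensorflow/estimator/model");
        // The resulting path is what would be passed to SavedModelBundle.load(path, "serve").
        System.out.println(latestExport(base).map(Path::toString).orElse("no export found"));
    }
}
```

Comparing numerically rather than lexicographically matters only if timestamp lengths differ, but it costs nothing and avoids a subtle ordering bug.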

Example 3: importModel

import org.tensorflow.SavedModelBundle; // package/class the method depends on
/**
 * Imports a saved TensorFlow model from a directory.
 * The model should be saved as a .pbtxt or .pb file.
 * The name of the model is taken as the pb/pbtxt file name (not including the file ending).
 *
 * @param modelDir the directory containing the TensorFlow model files to import
 */
public TensorFlowModel importModel(String modelDir) {
    try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
        return importModel(model);
    }
    catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
    }
}
 
Developer ID: vespa-engine, Project: vespa, Lines of code: 16, Source: TensorFlowImporter.java
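The Javadoc in Example 3 notes that the model must be stored as a .pb or .pbtxt file. A SavedModel directory holds its graph in saved_model.pb (binary protobuf) or saved_model.pbtxt (text protobuf), usually alongside a variables/ subdirectory, and SavedModelBundle.load fails with an exception if neither is present. A cheap pre-check like the hypothetical SavedModelDirs.check below can turn that into a clearer error message before TensorFlow is called. It is only a sketch: it inspects file names and does not validate the protobuf contents or the requested tags.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical pre-check for a SavedModel export directory (file names only). */
final class SavedModelDirs {
    private SavedModelDirs() {}

    /** Returns the graph file found in dir, or throws IllegalArgumentException with a clear message. */
    static Path check(Path dir) {
        if (!Files.isDirectory(dir)) {
            throw new IllegalArgumentException("Not a directory: '" + dir + "'");
        }
        // A SavedModel stores its graph as binary saved_model.pb or text saved_model.pbtxt.
        for (String name : new String[] {"saved_model.pb", "saved_model.pbtxt"}) {
            Path graph = dir.resolve(name);
            if (Files.isRegularFile(graph)) {
                return graph;
            }
        }
        throw new IllegalArgumentException(
                "No saved_model.pb or saved_model.pbtxt in '" + dir + "'");
    }

    public static void main(String[] args) {
        Path dir = Paths.get(args.length > 0 ? args[0] : ".");
        try {
            System.out.println("Found " + check(dir));
            // A real caller would continue with SavedModelBundle.load(dir.toString(), "serve").
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Throwing IllegalArgumentException keeps the pre-check consistent with the exception type Example 3 already catches and rewraps.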

Example 4: createBatch

import org.tensorflow.SavedModelBundle; // package/class the method depends on
@Override
protected ArchiveBatch createBatch(String name, String dataset, Predicate<FieldName> predicate){
	ArchiveBatch result = new IntegrationTestBatch(name, dataset, predicate){

		@Override
		public IntegrationTest getIntegrationTest(){
			return EstimatorTest.this;
		}

		@Override
		public PMML getPMML() throws Exception {
			File savedModelDir = getSavedModelDir();

			SavedModelBundle bundle = SavedModelBundle.load(savedModelDir.getAbsolutePath(), "serve");

			try(SavedModel savedModel = new SavedModel(bundle)){
				EstimatorFactory estimatorFactory = EstimatorFactory.newInstance();

				Estimator estimator = estimatorFactory.newEstimator(savedModel);

				PMML pmml = estimator.encodePMML();

				ensureValidity(pmml);

				return pmml;
			}
		}

		private File getSavedModelDir() throws IOException, URISyntaxException {
			ClassLoader classLoader = (EstimatorTest.this.getClass()).getClassLoader();

			String protoPath = ("savedmodel/" + getName() + getDataset() + "/saved_model.pbtxt");

			URL protoResource = classLoader.getResource(protoPath);
			if(protoResource == null){
				throw new NoSuchFileException(protoPath);
			}

			File protoFile = (Paths.get(protoResource.toURI())).toFile();

			return protoFile.getParentFile();
		}
	};

	return result;
}
 
Developer ID: jpmml, Project: jpmml-tensorflow, Lines of code: 47, Source: EstimatorTest.java

Example 5: testMnistSoftmaxImport

import org.tensorflow.SavedModelBundle; // package/class the method depends on
@Test
public void testMnistSoftmaxImport() {
    String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
    SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
    TensorFlowModel result = new TensorFlowImporter().importModel(model);

    // Check constants
    assertEquals(2, result.constants().size());

    Tensor constant0 = result.constants().get("Variable");
    assertNotNull(constant0);
    assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
                 constant0.type());
    assertEquals(7840, constant0.size());

    Tensor constant1 = result.constants().get("Variable_1");
    assertNotNull(constant1);
    assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
                 constant1.type());
    assertEquals(10, constant1.size());

    // Check signatures
    assertEquals(1, result.signatures().size());
    TensorFlowModel.Signature signature = result.signatures().get("serving_default");
    assertNotNull(signature);

    // ... signature inputs
    assertEquals(1, signature.inputs().size());
    TensorType argument0 = signature.inputArgument("x");
    assertNotNull(argument0);
    assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);

    // ... signature outputs
    assertEquals(1, signature.outputs().size());
    RankingExpression output = signature.outputExpression("y");
    assertNotNull(output);
    assertEquals("add", output.getName());
    assertEquals("" +
                 "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
                 "rename(constant(Variable_1), d0, d1), " +
                 "f(a,b)(a + b))",
                 toNonPrimitiveString(output));

    // Test execution
    assertEqualResult(model, result, "Placeholder", "Variable/read");
    assertEqualResult(model, result, "Placeholder", "Variable_1/read");
    assertEqualResult(model, result, "Placeholder", "MatMul");
    assertEqualResult(model, result, "Placeholder", "add");
}
 
Developer ID: vespa-engine, Project: vespa, Lines of code: 50, Source: TensorflowImportTestCase.java


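Note: The org.tensorflow.SavedModelBundle.load examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their developers, and copyright of the source code remains with the original authors. Consult each project's license before distributing or using the code, and do not republish without permission.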