本文整理汇总了Java中org.tensorflow.SavedModelBundle类的典型用法代码示例。如果您正苦于以下问题：Java SavedModelBundle类的具体用法？Java SavedModelBundle怎么用？Java SavedModelBundle使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
SavedModelBundle类属于org.tensorflow包，在下文中一共展示了SavedModelBundle类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: loadModel
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Loads a TensorFlow SavedModel, reusing a bundle already loaded for the same
 * model name and tag set.
 *
 * @param source where the (possibly packed) model data can be obtained
 * @param modelName name of the model; together with {@code tags} forms the cache key
 * @param tags the SavedModel tags to load
 * @return the loaded (or previously cached) model bundle
 * @throws IOException if the local model directory cannot be prepared
 */
@Override
public SavedModelBundle loadModel(final Location source,
    final String modelName, final String... tags) throws IOException
{
    final String key = modelName + "/" + Arrays.toString(tags);
    // If the model is already cached in memory, return it (single map lookup).
    final SavedModelBundle cached = models.get(key);
    if (cached != null) return cached;
    // Get a local directory with unpacked model data.
    final File modelDir = modelDir(source, modelName);
    // Load the saved model.
    final SavedModelBundle model =
        SavedModelBundle.load(modelDir.getAbsolutePath(), tags);
    // Remember the bundle: later calls reuse it instead of reloading, and
    // dispose() closes everything in `models`, so an uncached bundle would leak.
    models.put(key, model);
    return model;
}
示例2: dispose
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Closes every cached model bundle and graph and empties all caches,
 * including the cached labels.
 */
@Override
public void dispose() {
    // Close each cached SavedModelBundle, then forget them.
    models.values().forEach(SavedModelBundle::close);
    models.clear();
    // Close each cached Graph, then forget them.
    graphs.values().forEach(Graph::close);
    graphs.clear();
    // Drop the cached labels.
    labelses.clear();
}
示例3: importGraph
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Builds a {@link TensorFlowModel} containing one signature per entry in the
 * meta graph's signature map. Inputs are imported directly; each output is
 * imported by converting its producing node. Outputs whose conversion fails
 * with an {@link IllegalArgumentException} are recorded as skipped rather
 * than aborting the whole import.
 */
private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) {
    TensorFlowModel imported = new TensorFlowModel();
    for (Map.Entry<String, SignatureDef> entry : graph.getSignatureDefMap().entrySet()) {
        SignatureDef signatureDef = entry.getValue();
        // The map key is used as the signature name in preference to "methodName".
        TensorFlowModel.Signature signature = imported.signature(entry.getKey());
        importInputs(signatureDef.getInputsMap(), signature);
        for (Map.Entry<String, TensorInfo> outputEntry : signatureDef.getOutputsMap().entrySet()) {
            String outputName = outputEntry.getKey();
            try {
                String nodeName = nameOf(outputEntry.getValue().getName());
                NodeDef producer = getNode(nodeName, graph.getGraphDef());
                importNode(producer, graph.getGraphDef(), model, imported);
                signature.output(outputName, nodeName);
            } catch (IllegalArgumentException e) {
                signature.skippedOutput(outputName, Exceptions.toMessageString(e));
            }
        }
    }
    return imported;
}
示例4: tensorFunctionOf
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Maps a single TensorFlow node to the corresponding Vespa tensor function.
 *
 * @param tfNode the node to convert; its op name selects the mapping
 * @param graph the graph the node belongs to, used to import its arguments
 * @param model the loaded bundle, passed through for nodes that read variables
 * @param result the model being built; imported expressions/constants are added to it
 * @return the typed tensor function for this node
 * @throws IllegalArgumentException if the node's operation is not supported
 */
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
    // Import arguments lazily below, as some nodes have unused arguments leading to unsupported ops
    // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
    // Lowercase with Locale.ROOT so matching does not depend on the JVM's default locale
    // (e.g. in a Turkish locale "Identity" would otherwise become "\u0131dentity" and match nothing).
    switch (tfNode.getOp().toLowerCase(java.util.Locale.ROOT)) {
        // TensorFlow's n-ary add op is named "AddN", i.e. "addn" after lowercasing;
        // "add_n" (the Python function name) is kept for backward compatibility.
        case "add" : case "addn" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
        case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
        case "elu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.elu());
        case "identity" : return operationMapper.identity(tfNode, model, result);
        case "placeholder" : return operationMapper.placeholder(tfNode, result);
        case "relu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.relu());
        case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
        case "sigmoid": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.sigmoid());
        case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
        default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
    }
}
示例5: identity
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Converts an Identity node that reads a variable (name ending in "/read") into a
 * reference to a constant. The variable's current value is fetched from the model
 * session and registered on {@code result} as a constant.
 *
 * @param tfNode the identity node; must be a variable read with exactly one input
 * @param model the loaded bundle whose session is used to read the variable
 * @param result the model being built; receives the variable value as a constant
 * @return a function referencing {@code constant(<variable name>)}
 * @throws IllegalArgumentException if the node is not a supported variable read
 * @throws IllegalStateException if reading the variable does not yield exactly one tensor
 */
TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
    if ( ! tfNode.getName().endsWith("/read"))
        throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identity " +
                                           "nodes are only supported when reading variables");
    if (tfNode.getInputList().size() != 1)
        throw new IllegalArgumentException("A Variable/read node must have one input but has " +
                                           tfNode.getInputList().size());
    String name = tfNode.getInput(0);
    AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
    if (shapes == null)
        throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
    List<org.tensorflow.Tensor<?>> importedTensors = model.session().runner().fetch(name).run();
    Tensor constant;
    try {
        if (importedTensors.size() != 1)
            throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
                                            importedTensors.size());
        constant = tensorConverter.toVespaTensor(importedTensors.get(0));
    } finally {
        // TensorFlow tensors own native memory and must be closed explicitly.
        // NOTE(review): assumes toVespaTensor copies the values out — confirm it keeps no reference.
        for (org.tensorflow.Tensor<?> fetched : importedTensors)
            fetched.close();
    }
    result.constant(name, constant);
    return new TypedTensorFunction(constant.type(),
                                   new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
}
示例6: testLoadModel
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Loads the example estimator SavedModel (from the last subdirectory of the model
 * directory, tag "serve"), feeds a small float vector to input "x", fetches "y",
 * and logs the result.
 */
public void testLoadModel() throws Exception {
    String modelDir = "examples/tensorflow/estimator/model";
    // Closing the bundle releases both its Session and its Graph; the input and output
    // tensors hold native memory too, so every one is managed by try-with-resources
    // (closed in reverse order: y, x, bundle).
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir + "/" + SpongeUtils.getLastSubdirectory(modelDir), "serve");
         Tensor x = Tensor.create(new float[] { 2, 5, 8, 1 });
         Tensor y = bundle.session().runner().feed("x", x).fetch("y").run().get(0)) {
        logger.info("y = {}", y.floatValue());
    }
}
示例7: importModel
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Imports a TensorFlow SavedModel from a directory.
 * The directory must contain a model saved in the SavedModel format with the
 * "serve" tag; it is loaded, converted, and closed again before returning.
 *
 * @param modelDir the SavedModel directory to import
 * @return the imported model
 * @throws IllegalArgumentException if loading or converting the model fails with an
 *         IllegalArgumentException; the original exception is attached as the cause
 */
public TensorFlowModel importModel(String modelDir) {
    try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
        return importModel(model);
    }
    catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
    }
}
示例8: importNode
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Recursively converts a graph of TensorFlow nodes into a Vespa tensor function
 * expression tree, registering this node's expression on {@code result}.
 *
 * @return the converted function for {@code tfNode}
 */
private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
    TypedTensorFunction converted = tensorFunctionOf(tfNode, graph, model, result);
    String nodeName = tfNode.getName();
    // Every intermediate node is registered as its own expression; only the ones
    // referenced by a signature output end up being used.
    TensorFunctionNode root = new TensorFunctionNode(converted.function());
    result.expression(nodeName, new RankingExpression(nodeName, root));
    return converted;
}
示例9: assertEqualResult
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Asserts that the imported Vespa expression for {@code operationName} yields the
 * same tensor as executing that operation directly in TensorFlow.
 *
 * @param model the loaded TensorFlow bundle, used as the reference implementation
 * @param result the imported model whose expression is evaluated
 * @param inputName name of the placeholder input fed to both sides
 * @param operationName the operation / expression to compare
 */
private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) {
    Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
    Context context = contextFrom(result);
    // NOTE(review): assumes placeholderArgument() matches the zero-filled [1, 784] input
    // that tensorFlowExecute feeds to TensorFlow — confirm.
    Tensor placeholder = placeholderArgument();
    context.put(inputName, new TensorValue(placeholder));
    Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
    // assertEquals(message, expected, actual): TensorFlow's output is the expected reference,
    // so failure messages report the Vespa value as the "actual" one.
    assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
示例10: tensorFlowExecute
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Runs {@code operationName} in the model's TensorFlow session with a zero-filled
 * float input of shape [1, 784] fed to {@code inputName}, and converts the single
 * result into a Vespa tensor.
 *
 * @return the converted result of the operation
 */
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
    Session.Runner runner = model.session().runner();
    // FloatBuffer.allocate yields a zero-filled buffer, so the placeholder is all zeros.
    // TensorFlow tensors own native memory, so both the input and the outputs are closed explicitly.
    try (org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784))) {
        runner.feed(inputName, placeholder);
        List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
        try {
            assertEquals(1, results.size());
            return new TensorConverter().toVespaTensor(results.get(0));
        } finally {
            for (org.tensorflow.Tensor<?> fetched : results)
                fetched.close();
        }
    }
}
示例11: run
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Runs the microscope focus image-quality model on {@code originalImage}.
 * <p>
 * Validates and normalizes the image, loads the model through
 * {@code tensorFlowService} (timing the load), resolves the input/output
 * tensor names from the model's default serving signature, runs the normalized
 * image through the model session (timing the run), and hands the
 * "probabilities" and "patches" outputs to {@code processPatches}. Any exception
 * is reported via {@code log.error} instead of being propagated.
 */
@Override
public void run() {
try {
validateFormat(originalImage);
RandomAccessibleInterval<FloatType> normalizedImage = normalize(originalImage);
final long loadModelStart = System.nanoTime();
final HTTPLocation source = new HTTPLocation(MODEL_URL);
// NOTE(review): the bundle is not closed here; presumably its lifetime is
// managed by tensorFlowService (e.g. its model cache) — confirm.
final SavedModelBundle model = //
tensorFlowService.loadModel(source, MODEL_NAME, MODEL_TAG);
final long loadModelEnd = System.nanoTime();
// nanoTime difference / 1000000 converts nanoseconds to milliseconds.
log.info(String.format(
"Loaded microscope focus image quality model in %dms", (loadModelEnd -
loadModelStart) / 1000000));
// Extract names from the model signature.
// The strings "input", "probabilities" and "patches" are meant to be
// in sync with the model exporter (export_saved_model()) in Python.
final SignatureDef sig = MetaGraphDef.parseFrom(model.metaGraphDef())
.getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
// The input tensor holds native memory; try-with-resources releases it.
try (final Tensor inputTensor = Tensors.tensor(normalizedImage)) {
// Run the model.
final long runModelStart = System.nanoTime();
// Outputs are returned in the order of the fetch() calls:
// index 0 = "probabilities", index 1 = "patches".
final List<Tensor> fetches = model.session().runner() //
.feed(opName(sig.getInputsOrThrow("input")), inputTensor) //
.fetch(opName(sig.getOutputsOrThrow("probabilities"))) //
.fetch(opName(sig.getOutputsOrThrow("patches"))) //
.run();
final long runModelEnd = System.nanoTime();
log.info(String.format("Ran image through model in %dms", //
(runModelEnd - runModelStart) / 1000000));
// Process the results.
// Both output tensors are closed once processPatches returns.
try (final Tensor probabilities = fetches.get(0);
final Tensor patches = fetches.get(1))
{
processPatches(probabilities, patches);
}
}
}
catch (final Exception exc) {
// Use the LogService to report the error.
log.error(exc);
}
}
示例12: close
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Closes the underlying SavedModel bundle, releasing its session and graph.
 */
@Override
public void close() {
    getBundle().close();
}
示例13: getSession
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Returns the session owned by the underlying SavedModel bundle.
 *
 * @return the bundle's session
 */
public Session getSession() {
    return getBundle().session();
}
示例14: getGraph
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Returns the graph owned by the underlying SavedModel bundle.
 *
 * @return the bundle's graph
 */
public Graph getGraph() {
    return getBundle().graph();
}
示例15: getBundle
import org.tensorflow.SavedModelBundle; //导入依赖的package包/类
/**
 * Returns the SavedModel bundle held by this instance.
 *
 * @return the {@code bundle} field
 */
public SavedModelBundle getBundle() {
    return bundle;
}