示例1: getOperation

import org.tensorflow.Operation; //导入依赖的package包/类
public Operation getOperation(String id, String name) {
    Graph graph = graphs.get(id);
    if(graph != null && name != null) {
        return graph.operation(name);
    } else {
        return null;

示例2: graphOperation

import org.tensorflow.Operation; //导入依赖的package包/类
public Operation graphOperation(String operationName) {
  final Operation operation = g.operation(operationName);
  if (operation == null) {
    throw new RuntimeException(
        "Node '" + operationName + "' does not exist in model '" + modelName + "'");
  return operation;

示例3: inputListLength

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Resolves operation {@code opName} of graph {@code id} via {@code getGraphOperation}.
 *
 * <p>NOTE(review): this snippet is truncated. The code that uses {@code graphOperation} and
 * {@code name} is not shown; presumably it reports an input-list length through {@code promise}.
 * The catch body and closing braces are also missing. {@code getGraphOperation} may yield
 * {@code null} for an unknown graph or node; confirm the elided code checks for that.
 */
public void inputListLength(String id, String opName, String name, Promise promise) {
    try {
        Operation graphOperation = getGraphOperation(id, opName);
    } catch (Exception e) {

示例4: name

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Resolves operation {@code opName} of graph {@code id} via {@code getGraphOperation}.
 *
 * <p>NOTE(review): this snippet is truncated. The code that uses {@code graphOperation} is not
 * shown; presumably it reports the operation's name through {@code promise}. The catch body and
 * closing braces are also missing. {@code getGraphOperation} may yield {@code null}; confirm
 * it is checked.
 */
public void name(String id, String opName, Promise promise) {
    try {
        Operation graphOperation = getGraphOperation(id, opName);
    } catch (Exception e) {

示例5: numOutputs

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Resolves operation {@code opName} of graph {@code id} via {@code getGraphOperation}.
 *
 * <p>NOTE(review): this snippet is truncated. The code that uses {@code graphOperation} is not
 * shown; presumably it reports the number of outputs through {@code promise}. The catch body and
 * closing braces are also missing. {@code getGraphOperation} may yield {@code null}; confirm
 * it is checked.
 */
public void numOutputs(String id, String opName, Promise promise) {
    try {
        Operation graphOperation = getGraphOperation(id, opName);
    } catch (Exception e) {

示例6: output

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Resolves operation {@code opName} of graph {@code id} via {@code getGraphOperation}.
 *
 * <p>NOTE(review): this snippet is truncated. The code that uses {@code graphOperation} and
 * {@code index} is not shown; presumably it reports output {@code index} through
 * {@code promise}. The catch body and closing braces are also missing.
 * {@code getGraphOperation} may yield {@code null}; confirm it is checked.
 */
public void output(String id, String opName, int index, Promise promise) {
    try {
        Operation graphOperation = getGraphOperation(id, opName);
    } catch (Exception e) {

示例7: outputList

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Resolves operation {@code opName} of graph {@code id} and calls
 * {@code graphOperation.outputList(index, length)}. It then starts converting the returned
 * outputs into a {@code WritableNativeArray}.
 *
 * <p>NOTE(review): this snippet is truncated. The loop body is not shown. Neither is the code
 * after the loop, which presumably resolves {@code promise} with {@code outputsConverted}. The
 * catch body and closing braces are also missing. If {@code getGraphOperation} returns
 * {@code null}, the {@code outputList} call throws a NullPointerException. The
 * {@code catch (Exception e)} below would intercept it.
 */
public void outputList(String id, String opName, int index, int length, Promise promise) {
    try {
        Operation graphOperation = getGraphOperation(id, opName);
        Output[] outputs = graphOperation.outputList(index, length);
        WritableArray outputsConverted = new WritableNativeArray();
        for (Output output : outputs) {
    } catch (Exception e) {

示例8: outputListLength

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Resolves operation {@code opName} of graph {@code id} via {@code getGraphOperation}.
 *
 * <p>NOTE(review): this snippet is truncated. The code that uses {@code graphOperation} and
 * {@code name} is not shown; presumably it reports an output-list length through
 * {@code promise}. The catch body and closing braces are also missing.
 * {@code getGraphOperation} may yield {@code null}; confirm it is checked.
 */
public void outputListLength(String id, String opName, String name, Promise promise) {
    try {
        Operation graphOperation = getGraphOperation(id, opName);
    } catch (Exception e) {

示例9: type

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Resolves operation {@code opName} of graph {@code id} via {@code getGraphOperation}.
 *
 * <p>NOTE(review): this snippet is truncated. The code that uses {@code graphOperation} is not
 * shown; presumably it reports the operation's type through {@code promise}. The catch body and
 * closing braces are also missing. {@code getGraphOperation} may yield {@code null}; confirm
 * it is checked.
 */
public void type(String id, String opName, Promise promise) {
    try {
        Operation graphOperation = getGraphOperation(id, opName);
    } catch (Exception e) {

示例10: create

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Initializes a native TensorFlow session for MultiBox object detection in images.
 *
 * @param assetManager The asset manager to be used to load assets.
 * @param modelFilename The filepath of the model GraphDef protocol buffer.
 * @param locationFilename The filepath of the box-prior (location encoding) file.
 * @param imageMean The assumed mean of the image values.
 * @param imageStd The assumed std of the image values.
 * @param inputName The label of the image input node. Its height, read from the graph,
 *     determines the square input size.
 * @param outputLocationsName The label of the output node holding box locations.
 * @param outputScoresName The label of the output node holding per-location scores.
 * @return the configured detector
 * @throws RuntimeException if a named node is missing or the box priors cannot be loaded
 */
public static Classifier create(
    final AssetManager assetManager,
    final String modelFilename,
    final String locationFilename,
    final int imageMean,
    final float imageStd,
    final String inputName,
    final String outputLocationsName,
    final String outputScoresName) {
  final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();

  d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

  final Graph g = d.inferenceInterface.graph();

  d.inputName = inputName;
  // The inputName node has a shape of [N, H, W, C], where
  // N is the batch size
  // H = W are the height and width
  // C is the number of channels (3 for our purposes - RGB)
  final Operation inputOp = g.operation(inputName);
  if (inputOp == null) {
    throw new RuntimeException("Failed to find input Node '" + inputName + "'");
  d.inputSize = (int) inputOp.output(0).shape().size(1);
  d.imageMean = imageMean;
  d.imageStd = imageStd;
  // The outputScoresName node has a shape of [N, NumLocations], where N
  // is the batch size.
  final Operation outputOp = g.operation(outputScoresName);
  if (outputOp == null) {
    throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'");
  d.numLocations = (int) outputOp.output(0).shape().size(1);

  d.boxPriors = new float[d.numLocations * 8];

  try {
    d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
  } catch (final IOException e) {
    throw new RuntimeException("Error initializing box priors from " + locationFilename);

  // Pre-allocate buffers.
  d.outputNames = new String[] {outputLocationsName, outputScoresName};
  d.intValues = new int[d.inputSize * d.inputSize];
  d.floatValues = new float[d.inputSize * d.inputSize * 3];
  d.outputScores = new float[d.numLocations];
  d.outputLocations = new float[d.numLocations * 4];

  return d;

示例11: getGraphOperation

import org.tensorflow.Operation; //导入依赖的package包/类
private Operation getGraphOperation(String id, String name) {
    return getReactApplicationContext().getNativeModule(RNTensorFlowGraphModule.class).getOperation(id, name);

示例12: loadNetwork

import org.tensorflow.Operation; //导入依赖的package包/类
    /**
     * Reads the bytes of {@code f} into {@code graphDef} (presumably a serialized GraphDef). It
     * then builds a numbered text listing of the operations in {@code executionGraph}. Operations
     * whose lower-cased {@code toString()} contains "input", "placeholder" or "output" get a
     * "**********" marker prefix.
     *
     * <p>NOTE(review): this snippet is truncated. Several parts are missing:
     * <ul>
     *   <li>the closing braces;</li>
     *   <li>what happens to the listing in {@code b};</li>
     *   <li>the catch body;</li>
     *   <li>any {@code executionGraph.importGraphDef(graphDef)} call.</li>
     * </ul>
     * As visible, operations are iterated on a freshly constructed {@code Graph}. Confirm the
     * import step exists in the full source.
     *
     * @param f file to read the graph bytes from
     * @throws IOException if {@code f} is {@code null} (exceptions inside the {@code try} are
     *     caught by the {@code catch (Exception e)} below)
     */
    public void loadNetwork(File f) throws IOException {
        if (f == null) {
            throw new IOException("null file");
        try {
            graphDef = Files.readAllBytes(Paths.get(f.getAbsolutePath())); // "tensorflow_inception_graph.pb"
            executionGraph = new Graph();
            Iterator<Operation> itr = executionGraph.operations();
            StringBuilder b = new StringBuilder("TensorFlow Graph: \n");
            int opnum = 0;
            while (itr.hasNext()) {
                Operation o = itr.next();
                // NOTE(review): toLowerCase() uses the default locale; Locale.ROOT would be safer.
                final String s = o.toString().toLowerCase();
//                if(s.contains("input") || s.contains("output") || s.contains("placeholder")){
                if (s.contains("input") || s.contains("placeholder") || s.contains("output")) {  // find input placeholder & output
                    // The commented-out lines below are a disabled per-output shape dump.
//                    int numOutputs = o.numOutputs();
                    b.append("********** ");
//                    for (int onum = 0; onum < numOutputs; onum++) {
//                        Output output = o.output(onum);
//                        Shape shape = output.shape();
//                        int numDimensions = shape.numDimensions();
//                        for (int dimidx = 0; dimidx < numDimensions; dimidx++) {
//                            long dim = shape.size(dimidx);
//                        }
//                    }
//                    int inputLength=o.inputListLength("");
                // Every operation gets a numbered line with its toString() description.
                b.append(opnum++ + ": " + o.toString() + "\n");
        } catch (Exception e) {

示例13: create

import org.tensorflow.Operation; //导入依赖的package包/类
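/**
 * Initializes a native TensorFlow session for MultiBox object detection in images.
 *
 * @param assetManager The asset manager to be used to load assets.
 * @param modelFilename The filepath of the model GraphDef protocol buffer.
 * @param locationFilename The filepath of the box-prior (location encoding) file.
 * @param imageMean The assumed mean of the image values.
 * @param imageStd The assumed std of the image values.
 * @param inputName The label of the image input node. Its height, read from the graph,
 *     determines the square input size.
 * @param outputLocationsName The label of the output node holding box locations.
 * @param outputScoresName The label of the output node holding per-location scores.
 * @return the configured detector
 * @throws RuntimeException if TensorFlow initialization fails, a named node is missing, or the
 *     box priors cannot be loaded
 */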
public static Classifier create(
    final AssetManager assetManager,
    final String modelFilename,
    final String locationFilename,
    final int imageMean,
    final float imageStd,
    final String inputName,
    final String outputLocationsName,
    final String outputScoresName) {
  final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();

  d.inferenceInterface = new TensorFlowInferenceInterface();
  if (d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
    throw new RuntimeException("TF initialization failed");

  final Graph g = d.inferenceInterface.graph();

  d.inputName = inputName;
  // The inputName node has a shape of [N, H, W, C], where
  // N is the batch size
  // H = W are the height and width
  // C is the number of channels (3 for our purposes - RGB)
  final Operation inputOp = g.operation(inputName);
  if (inputOp == null) {
    throw new RuntimeException("Failed to find input Node '" + inputName + "'");
  d.inputSize = (int) inputOp.output(0).shape().size(1);
  d.imageMean = imageMean;
  d.imageStd = imageStd;
  // The outputScoresName node has a shape of [N, NumLocations], where N
  // is the batch size.
  final Operation outputOp = g.operation(outputScoresName);
  if (outputOp == null) {
    throw new RuntimeException("Failed to find output Node '" + outputScoresName + "'");
  d.numLocations = (int) outputOp.output(0).shape().size(1);

  d.boxPriors = new float[d.numLocations * 8];

  try {
    d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
  } catch (final IOException e) {
    throw new RuntimeException("Error initializing box priors from " + locationFilename);

  // Pre-allocate buffers.
  d.outputNames = new String[] {outputLocationsName, outputScoresName};
  d.intValues = new int[d.inputSize * d.inputSize];
  d.floatValues = new float[d.inputSize * d.inputSize * 3];
  d.outputScores = new float[d.numLocations];
  d.outputLocations = new float[d.numLocations * 4];

  return d;

示例14: ensureDataField

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Returns the data field named after {@code placeholder}. If {@code getDataField} finds no such
 * field, a new one is built with {@code createDataField}. Its op type and data type are derived
 * via {@code TypeUtil} from output 0 of the operation that {@code savedModel} returns for the
 * placeholder's name.
 *
 * <p>NOTE(review): the condition guarding the {@code IllegalArgumentException} below is elided
 * in this snippet, as are the closing braces. Presumably the guard rejects nodes that are not
 * placeholders; confirm against the full source.
 *
 * @param savedModel model used to resolve the placeholder's operation
 * @param placeholder graph node whose name becomes the field name
 * @return the existing or newly built data field
 * @throws IllegalArgumentException (under the elided condition) with the node name as message
 */
public DataField ensureDataField(SavedModel savedModel, NodeDef placeholder){
		// NOTE(review): the guarding `if` for the following throw is not shown in this snippet.
			throw new IllegalArgumentException(placeholder.getName());

		FieldName name = FieldName.create(placeholder.getName());

		DataField dataField = getDataField(name);
		if(dataField == null){
			// No field yet: derive its types from output 0 of the placeholder's operation.
			Operation operation = savedModel.getOperation(placeholder.getName());

			Output output = operation.output(0);

			dataField = createDataField(name, TypeUtil.getOpType(output), TypeUtil.getDataType(output));

		return dataField;

示例15: createContinuousFeature

import org.tensorflow.Operation; //导入依赖的package包/类
/**
 * Builds a continuous feature for {@code placeholder}.
 *
 * <p>Under a condition elided from this snippet, the node passed in is first remembered as
 * {@code cast}. {@code placeholder} is then replaced by the node named by its first input. The
 * data field is ensured for the resulting placeholder. When {@code cast} was set, the feature is
 * converted to the data type of output 0 of the {@code cast} node's operation.
 *
 * <p>NOTE(review): the elided condition presumably checks whether the node is a Cast op; confirm
 * against the full source. The closing braces are also missing here.
 *
 * @param savedModel model used to resolve nodes and operations
 * @param placeholder input node, possibly wrapped by a cast
 * @return the continuous feature, converted to the cast's data type when applicable
 */
public ContinuousFeature createContinuousFeature(SavedModel savedModel, NodeDef placeholder){
	NodeDef cast = null;
	// NOTE(review): the `if` that guards the two statements below is not shown in this snippet.
		cast = placeholder;
		placeholder = savedModel.getNodeDef(placeholder.getInput(0));

	DataField dataField = ensureContinuousDataField(savedModel, placeholder);

	ContinuousFeature result = new ContinuousFeature(this, dataField);

	if(cast != null){
		// Re-type the feature to match the cast operation's output data type.
		Operation operation = savedModel.getOperation(cast.getName());

		Output output = operation.output(0);

		result = result.toContinuousFeature(TypeUtil.getDataType(output));

	return result;
