當前位置: 首頁>>技術教程>>正文


安裝TensorFlow for Java

TensorFlow提供用於Java程序的API。這些API非常適合於加載在Python中創建的模型並在Java應用程序中執行。本指南說明如何安裝TensorFlow for Java並在Java應用程序中使用它。

警告:TensorFlow Java API目前沒有TensorFlowAPI穩定性保證

支持的平台

以下操作係統支持TensorFlow for Java:

  • Linux
  • Mac OS X
  • Windows
  • Android的

Android的安裝說明在單獨的文檔中:Android TensorFlow支持頁麵。安裝完成後請看這個TensorFlow在Android上使用的完整的例子的。

在Maven項目中使用TensorFlow

如果您的項目使用Apache Maven,需要將以下內容添加到項目中pom.xml,才能使用TensorFlow Java API:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.3.0</version>
</dependency>

示例

以下步驟將創建一個使用TensorFlow的Maven項目:

  1. 創建項目的pom.xml

     <project>
         <modelVersion>4.0.0</modelVersion>
         <groupId>org.myorg</groupId>
         <artifactId>hellotf</artifactId>
         <version>1.0-SNAPSHOT</version>
         <properties>
           <exec.mainClass>HelloTF</exec.mainClass>
           <!-- The sample code requires at least JDK 1.7. -->
           <!-- The maven compiler plugin defaults to a lower version -->
           <maven.compiler.source>1.7</maven.compiler.source>
           <maven.compiler.target>1.7</maven.compiler.target>
         </properties>
         <dependencies>
           <dependency>
             <groupId>org.tensorflow</groupId>
             <artifactId>tensorflow</artifactId>
             <version>1.3.0</version>
           </dependency>
         </dependencies>
     </project>
    
  2. 創建源文件(src/main/java/HelloTF.java):

    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;
    import org.tensorflow.TensorFlow;
    
    public class HelloTF {
      public static void main(String[] args) throws Exception {
        try (Graph g = new Graph()) {
          final String value = "Hello from " + TensorFlow.version();
    
          // Construct the computation graph with a single operation, a constant
          // named "MyConst" with a value "value".
          try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
            // The Java API doesn't yet include convenience functions for adding operations.
            g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
          }
    
          // Execute the "MyConst" operation in a Session.
          try (Session s = new Session(g);
               Tensor output = s.runner().fetch("MyConst").run().get(0)) {
            System.out.println(new String(output.bytesValue(), "UTF-8"));
          }
        }
      }
    }
    
  3. 編譯執行:

     # Use -q to hide logging from the mvn tool
     mvn -q compile exec:java

上述命令輸出Hello from 版本。如果沒有,請檢查Stack Overflow尋找可能的解決方案。

如果隻使用Maven開發,您可以跳過閱讀本文檔剩餘的部分。

在JDK中使用TensorFlowK

本節介紹如何基於來自己JDK的命令javajavac來使用TensorFlow。如果您的項目使用Apache Maven,那麽請參考上麵更簡單的說明。

在Linux或Mac OS上安裝

采取以下步驟在Linux或Mac OS上安裝TensorFlow for Java:

  1. 下載libtensorflow.jar,它是TensorFlow Java Archive(JAR)。

  2. 選擇要不要GPU支持。如何選擇?請閱讀以下指南中的“確定要安裝的TensorFlow”的部分。

  3. 通過運行以下shell命令,為您的操作係統和處理器下載並提取適當的Java Native Interface(JNI)文件:

     TF_TYPE="cpu" # Default processor is CPU. If you want GPU, set to "gpu"
     OS=$(uname -s | tr '[:upper:]' '[:lower:]')
     mkdir -p ./jni
     curl -L \
       "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.3.0.tar.gz" |
       tar -xz -C ./jni
    

在Windows上安裝

采取以下步驟在Windows上安裝TensorFlow for Java:

  1. 下載libtensorflow.jar,它是TensorFlow Java Archive(JAR)。
  2. 下載適合的Java Native Interface(JNI)文件Windows的TensorFlow for Java
  3. 解壓縮此.zip文件。

驗證安裝

安裝TensorFlow for Java後,通過將以下代碼輸入到名為HelloTF.java的文件中來驗證安裝:

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

public class HelloTF {
  public static void main(String[] args) throws Exception {
    try (Graph g = new Graph()) {
      final String value = "Hello from " + TensorFlow.version();

      // Construct the computation graph with a single operation, a constant
      // named "MyConst" with a value "value".
      try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
        // The Java API doesn't yet include convenience functions for adding operations.
        g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
      }

      // Execute the "MyConst" operation in a Session.
      try (Session s = new Session(g);
           Tensor output = s.runner().fetch("MyConst").run().get(0)) {
        System.out.println(new String(output.bytesValue(), "UTF-8"));
      }
    }
  }
}

使用下麵的說明編譯並運行HelloTF.java

編譯

當編譯使用TensorFlow的Java程序時,下載的.jar需要添加到classpath。例如,您也可以通過使用-cp編譯標誌將下載的.jar指定到classpath,如下:

javac -cp libtensorflow-1.3.0.jar HelloTF.java

運行

要執行依賴於TensorFlow的Java程序,請確保以下兩個文件可用於JVM:

  • 下載的.jar文件
  • 提取的JNI庫

例如,在Linux和Mac OS X上以下命令行運行HelloTF程序:

java -cp libtensorflow-1.3.0.jar:. -Djava.library.path=./jni HelloTF

在Windows上執行以下命令行運行HelloTF的程序:

java -cp libtensorflow-1.3.0.jar;. -Djava.library.path=jni HelloTF

上述命令輸出Hello from 版本。如果程序輸出其他內容,請檢查Stack Overflow尋找可能的解決方案

高級示例

有關更複雜的例子,請參閱LabelImage.java,這個例子用於圖像中實體檢測。

從源代碼構建

TensorFlow是開源的,所以可以從TensorFlow源代碼中構建TensorFlow for Java,方法見文檔Tensorflor JAVA安裝

參考資料

本文由《純淨天空》出品。文章地址: https://vimsky.com/zh-tw/article/3623.html,未經允許,請勿轉載。