5 分鐘用 Java 完成 PyTorch 物體識別!

近年來 PyTorch 在深度學習領域中的應用日趨廣泛。主要的原因為其以 Python 為主的介面設計提供了絕佳的使用體驗。使用者可以用簡單易懂的 Python 指令快速實作數據的平行處理、深度學習模型的訓練及部署、甚至是動態計算圖等功能。因此近年越來越多機器學習相關的學術論文選擇使用 PyTorch 作為實驗框架 (framework) 以驗證新的深度學習模型及訓練方法。

部署一個 PyTorch 模型有很多選擇,但是對 Java 開發人員友善的選項卻屈指可數。在過去,Java 開發者可以透過JNI (Java Native Interface) 呼叫 PyTorch C++ API 來解決這個問題,但需要開發者自行處理介面(interface)及資料的傳輸。雖然最近 PyTorch 1.4 發佈了實驗性的 Java API,但是並不支援運算子 (operator)。針對於這個問題,亞馬遜雲端服務 (AWS) 開源了 Deep Java Library (DJL),一個為Java開發者設計的深度學習庫。它兼顧了易用性和可維護性,運行效率以及記憶體管理問題都得到了很好的處理。DJL 使用起來非常簡單。只需幾行程式碼,用戶就可以輕鬆部署深度學習模型用於推論 (inference)。現在,就讓我們開始用 DJL 部署一個 PyTorch 模型吧!

▍在開始之前

使用者可以用 Maven 或者 Gradle 等 Java 常用的專案建置自動化工具來開發 DJL。以下是 build.gradle 範例:

plugins {
    id 'java'
}
repositories {                           
    jcenter()
}
dependencies {
    implementation "ai.djl:api:0.6.0"
    implementation "ai.djl:repository:0.6.0"
    runtimeOnly "ai.djl.pytorch:pytorch-model-zoo:0.6.0"
    runtimeOnly "ai.djl.pytorch:pytorch-native-auto:1.5.0"
}

接著只需 gradle build,基本設定就大功告成了。

▍部署模型

我們使用 NVIDIA 在 TorchHub 發佈的預訓練模型來推論 (inference) 下面這張圖中幾個可以識別的物體(狗,腳踏車以及卡車)。

我們可以透過下面的程式碼來實作推論 (inference) 的邏輯:

public static void main(String[] args) throws IOException, ModelException, TranslateException {
    // 讀取一張圖片
    Path imageFile = Paths.get("https://github.com/awslabs/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg");
    Image img = ImageFactory.getInstance().fromFile(imageFile);
    
    // 創建一個搜尋模型的條件
    Criteria<Image, DetectedObjects> criteria =
            Criteria.builder()
                    // 設定應用類型:物體識別
                    .optApplication(Application.CV.OBJECT_DETECTION)
                    // 確定輸入輸出類型(使用預設的圖片處理工具)
                    .setTypes(Image.class, DetectedObjects.class)
                    // 模型的過濾條件
                    .optFilter("backbone", "resnet50")
                    .optProgress(new ProgressBar())
                    .build();
    
    // 創建一個模型
    try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
        // 創建一個 Predictor
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            // 識別物體
            DetectedObjects detection = predictor.predict(img);
            System.out.println(detection);
        }
    }
} 

執行上面的程式,你會看到包含物體的標籤 (labels) 及其對應機率的輸出結果。

[
    class: "dog", probability: 0.96709, bounds: [x=0.165, y=0.348, width=0.249, height=0.539]
    class: "bicycle", probability: 0.66796, bounds: [x=0.152, y=0.244, width=0.574, height=0.562]
    class: "truck", probability: 0.64912, bounds: [x=0.609, y=0.132, width=0.284, height=0.166]
]

除此之外,我們也可以使用 DJL 提供的物體識別圖形化 API 將實際推論 (inference) 的結果視覺化:

DJL 目前擁有一個包含 70 多種模型的模型庫 (ModelZoo),其中模型來自於 GluonCV, TorchHub, Keras, HuggingFace 等深度學習框架的電腦視覺及自然語言處理預訓練模型。我們提供了幾種常用的前處理與後處理類別 (class) ,使用者也能實作客製化的前處理與後處理,便能輕鬆導入模型庫中的模型。我們還在不斷的擴充各種新的預訓練模型使 DJL 的模型庫更實用。

▍DJL 是什麼?

DJL 是亞馬遜雲端服務在 2019 年 re:Invent 大會推出專為 Java 開發者量身打造的深度學習框架 (framework),現已廣泛運行在亞馬遜的服務中。DJL 的主要特色有以下三點:
DJL 支援各種深度學習框架 (framework)。使用者可以輕鬆的使用 Java 呼叫 MXNet, PyTorch, TensorFlow, fastText, ONNX Runtime 做模型訓練和推論,即使切換框架 (framework) 也不會改變結果。

DJL 運算子的設計參照了 NumPy,所以在使用體驗上和 NumPy 非常相似,亦非常好上手。

DJL 擁有優秀的記憶體管理機制。即使 100 個小時連續推論也不會有記憶體不足的情況。

James Gosling (Java 創始人) 在使用 DJL 後給出了讚譽:

 免費註冊 AWS 帳號

新戶註冊即享 AWS 免費方案,可探索超過 100 種 AWS 的產品與服務,還能加碼領取獨家贈品!

 與我們聯絡

若欲尋求技術、帳單帳戶、登入存取支援,或希望與 AWS 的雲端業務聯絡,都竭誠歡迎您與我們聯繫!

 探索台灣資源中心

集結研討會精采回顧雲端主題白皮書開始上雲系列等免費資源,進一步豐富您的雲端之旅。