5 分鐘用 Java 完成 PyTorch 物體識別!
近年來 PyTorch 在深度學習領域中的應用日趨廣泛。主要的原因為其以 Python 為主的介面設計提供了絕佳的使用體驗。使用者可以用簡單易懂的 Python 指令快速實作數據的平行處理、深度學習模型的訓練及部署、甚至是動態計算圖等功能。因此近年越來越多機器學習相關的學術論文選擇使用 PyTorch 作為實驗框架 (framework) 以驗證新的深度學習模型及訓練方法。
部署一個 PyTorch 模型有很多選擇,但是對 Java 開發人員友善的選項卻屈指可數。在過去,Java 開發者可以透過JNI (Java Native Interface) 呼叫 PyTorch C++ API 來解決這個問題,但需要開發者自行處理介面(interface)及資料的傳輸。雖然最近 PyTorch 1.4 發佈了實驗性的 Java API,但是並不支援運算子 (operator)。針對於這個問題,亞馬遜雲端服務 (AWS) 開源了 Deep Java Library (DJL),一個為Java開發者設計的深度學習庫。它兼顧了易用性和可維護性,運行效率以及記憶體管理問題都得到了很好的處理。DJL 使用起來非常簡單。只需幾行程式碼,用戶就可以輕鬆部署深度學習模型用於推論 (inference)。現在,就讓我們開始用 DJL 部署一個 PyTorch 模型吧!
▍在開始之前
使用者可以用 Maven 或者 Gradle 等 Java 常用的專案建置自動化工具來開發 DJL。以下是 build.gradle 範例:
plugins {
id 'java'
}
repositories {
jcenter()
}
dependencies {
implementation "ai.djl:api:0.6.0"
implementation "ai.djl:repository:0.6.0"
runtimeOnly "ai.djl.pytorch:pytorch-model-zoo:0.6.0"
runtimeOnly "ai.djl.pytorch:pytorch-native-auto:1.5.0"
}
接著只需 gradle build,基本設定就大功告成了。
▍部署模型
我們使用 NVIDIA 在 TorchHub 發佈的預訓練模型來推論 (inference) 下面這張圖中幾個可以識別的物體(狗,腳踏車以及卡車)。
我們可以透過下面的程式碼來實作推論 (inference) 的邏輯:
public static void main(String[] args) throws IOException, ModelException, TranslateException {
// 讀取一張圖片
Path imageFile = Paths.get("https://github.com/awslabs/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg");
Image img = ImageFactory.getInstance().fromFile(imageFile);
// 創建一個搜尋模型的條件
Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
// 設定應用類型:物體識別
.optApplication(Application.CV.OBJECT_DETECTION)
// 確定輸入輸出類型(使用預設的圖片處理工具)
.setTypes(Image.class, DetectedObjects.class)
// 模型的過濾條件
.optFilter("backbone", "resnet50")
.optProgress(new ProgressBar())
.build();
// 創建一個模型
try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
// 創建一個 Predictor
try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
// 識別物體
DetectedObjects detection = predictor.predict(img);
System.out.println(detection);
}
}
}
執行上面的程式,你會看到包含物體的標籤 (labels) 及其對應機率的輸出結果。
[
class: "dog", probability: 0.96709, bounds: [x=0.165, y=0.348, width=0.249, height=0.539]
class: "bicycle", probability: 0.66796, bounds: [x=0.152, y=0.244, width=0.574, height=0.562]
class: "truck", probability: 0.64912, bounds: [x=0.609, y=0.132, width=0.284, height=0.166]
]
除此之外,我們也可以使用 DJL 提供的物體識別圖形化 API 將實際推論 (inference) 的結果視覺化:
DJL 目前擁有一個包含 70 多種模型的模型庫 (ModelZoo),其中模型來自於 GluonCV, TorchHub, Keras, HuggingFace 等深度學習框架的電腦視覺及自然語言處理預訓練模型。我們提供了幾種常用的前處理與後處理類別 (class) ,使用者也能實作客製化的前處理與後處理,便能輕鬆導入模型庫中的模型。我們還在不斷的擴充各種新的預訓練模型使 DJL 的模型庫更實用。
▍DJL 是什麼?
DJL 是亞馬遜雲端服務在 2019 年 re:Invent 大會推出專為 Java 開發者量身打造的深度學習框架 (framework),現已廣泛運行在亞馬遜的服務中。DJL 的主要特色有以下三點:
DJL 支援各種深度學習框架 (framework)。使用者可以輕鬆的使用 Java 呼叫 MXNet, PyTorch, TensorFlow, fastText, ONNX Runtime 做模型訓練和推論,即使切換框架 (framework) 也不會改變結果。
DJL 運算子的設計參照了 NumPy,所以在使用體驗上和 NumPy 非常相似,亦非常好上手。
DJL 擁有優秀的記憶體管理機制。即使 100 個小時連續推論也不會有記憶體不足的情況。
James Gosling (Java 創始人) 在使用 DJL 後給出了讚譽: