Amazon DJL:Java AI 的大一统方案

在 Java AI 领域,除了 Deep.

在 Java AI 领域,除了 Deeplearning4j,还有一个重量级玩家:Amazon DJL (Deep Java Library)
它的设计理念非常独特:不造轮子,而是做轮子的统一接口
DJL 底层可以切换 PyTorch, TensorFlow, MXNet, ONNX 等各种引擎,但上层提供一套统一的 Java API。这就好比 JDBC 之于各种数据库。

核心优势

  1. 引擎无关 (Engine Agnostic):你写的一套 Java 代码,换个依赖就能跑在 PyTorch 或 TensorFlow 上。
  2. 模型库 (Model Zoo):内置了大量预训练好的模型(ResNet, BERT, YOLO 等),开箱即用。
  3. 原生 Java 体验:不像某些库只是简单的 C++ Wrapper,DJL 的 API 设计非常符合 Java 最佳实践。

实战:使用 YOLOv5 做目标检测

我们不需要自己训练模型,直接从 Model Zoo 里加载一个现成的。

1. 引入依赖

<dependencies>
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.28.0</version>
    </dependency>
    <!-- 选择 PyTorch 作为底层引擎 -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.28.0</version>
    </dependency>
    <!-- 模型库支持 -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-model-zoo</artifactId>
        <version>0.28.0</version>
    </dependency>
</dependencies>

2. 加载模型并预测

import ai.djl.Application;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

import java.nio.file.Path;
import java.nio.file.Paths;

public class ObjectDetection {
    public static void main(String[] args) throws Exception {
        // 1. 读取图片
        Path imagePath = Paths.get("dog.jpg");
        Image img = ImageFactory.getInstance().fromFile(imagePath);

        // 2. 定义筛选条件:我们需要一个目标检测模型
        Criteria<Image, DetectedObjects> criteria = Criteria.builder()
                .optApplication(Application.CV.OBJECT_DETECTION)
                .setTypes(Image.class, DetectedObjects.class)
                .optFilter("backbone", "resnet50") // 选择 resnet50 骨架
                .build();

        // 3. 加载模型
        try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
             Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {

            // 4. 预测
            DetectedObjects detection = predictor.predict(img);

            // 5. 打印结果
            System.out.println(detection);
        }
    }
}

这段代码会自动下载模型,并在你的图片上圈出物体(比如 “Dog: 0.99″)。

总结

如果你不想关心底层的数学原理,只想快速把业界最先进的模型(比如 Stable Diffusion, BERT)用到你的 Java 项目里,DJL 是目前最方便的选择。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注