Java AI 部署:如何在 Java 中加载 Keras/Python 模型

在实际工程中,最常见的模式是:算法团队用.

在实际工程中,最常见的模式是:算法团队用 Python/PyTorch/Keras 训练模型,工程团队用 Java/Spring Boot 部署上线
这就涉及到一个问题:怎么把 Python 的模型给 Java 用?

Deeplearning4j (DL4J) 提供了非常完善的模型导入功能,支持 Keras (H5) 和 TensorFlow (SavedModel) 格式。

1. Python 端:保存模型

假设你的算法同事用 Keras 训练了一个简单的分类模型:

# Python 代码
from tensorflow import keras

# ... 训练过程 ...

# 保存整个模型(包含结构和权重)
model.save('my_weather_model.h5')

2. Java 端:引入依赖

除了核心的 deeplearning4j-core,你需要加一个 model-import 模块:

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-modelimport</artifactId>
    <version>1.0.0-M2.1</version>
</dependency>

3. Java 端:加载模型

加载代码简直简单得令人发指,只需要一行:

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class ModelLoader {
    public static void main(String[] args) throws Exception {
        String modelPath = "path/to/my_weather_model.h5";

        // 加载模型 (false 表示不强制在用户端进行训练配置,适合仅做推理)
        MultiLayerNetwork net = KerasModelImport.importKerasSequentialModelAndWeights(modelPath, false);

        System.out.println("模型加载成功!");
        System.out.println(net.summary());
    }
}

4. 推理 (Inference)

模型加载进来了,怎么用呢?
你需要把输入数据转换成 INDArray(即 ND4J 的 Tensor)。

假设模型的输入是一个长度为 3 的向量 [温度, 湿度, 风速]

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.ndarray.INDArray;

// ...

// 构造输入数据:1行3列
INDArray input = Nd4j.create(new float[]{25.5f, 0.65f, 3.2f}, new int[]{1, 3});

// 预测
INDArray output = net.output(input);

System.out.println("预测结果:" + output);

常见坑

  1. 版本对应:Keras 版本更新很快,DL4J 的导入模块有时候跟不上最新的 Keras 格式。建议在 Python 端保存时使用 Keras 2.x 的通用格式。
  2. 图片预处理:如果输入是图片,Python 端通常做了归一化(除以255),Java 端读取图片后也要做完全一致的操作,否则预测结果会不准。可以使用 NativeImageLoader 来辅助。

总结

通过 DL4J 的导入功能,我们完美实现了“Python 训练,Java 上线”的闭环。这对于已有大量 Java 基础设施的公司来说,是成本最低的 AI 落地方式。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注