在实际工程中,最常见的模式是:算法团队用 Python/PyTorch/Keras 训练模型,工程团队用 Java/Spring Boot 部署上线。
这就涉及到一个问题:怎么把 Python 的模型给 Java 用?
Deeplearning4j (DL4J) 提供了非常完善的模型导入功能,支持 Keras (H5) 和 TensorFlow (SavedModel) 格式。
1. Python 端:保存模型
假设你的算法同事用 Keras 训练了一个简单的分类模型:
# Python 代码
from tensorflow import keras
# ... 训练过程 ...
# 保存整个模型(包含结构和权重)
model.save('my_weather_model.h5')
2. Java 端:引入依赖
除了核心的 deeplearning4j-core,你需要加一个 model-import 模块:
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-modelimport</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
3. Java 端:加载模型
加载代码简直简单得令人发指,只需要一行:
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
public class ModelLoader {
public static void main(String[] args) throws Exception {
String modelPath = "path/to/my_weather_model.h5";
// 加载模型 (false 表示不强制在用户端进行训练配置,适合仅做推理)
MultiLayerNetwork net = KerasModelImport.importKerasSequentialModelAndWeights(modelPath, false);
System.out.println("模型加载成功!");
System.out.println(net.summary());
}
}
4. 推理 (Inference)
模型加载进来了,怎么用呢?
你需要把输入数据转换成 INDArray(即 ND4J 的 Tensor)。
假设模型的输入是一个长度为 3 的向量 [温度, 湿度, 风速]:
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.ndarray.INDArray;
// ...
// 构造输入数据:1行3列
INDArray input = Nd4j.create(new float[]{25.5f, 0.65f, 3.2f}, new int[]{1, 3});
// 预测
INDArray output = net.output(input);
System.out.println("预测结果:" + output);
常见坑
- 版本对应:Keras 版本更新很快,DL4J 的导入模块有时候跟不上最新的 Keras 格式。建议在 Python 端保存时使用 Keras 2.x 的通用格式。
- 图片预处理:如果输入是图片,Python 端通常做了归一化(除以255),Java 端读取图片后也要做完全一致的操作,否则预测结果会不准。可以使用
NativeImageLoader来辅助。
总结
通过 DL4J 的导入功能,我们完美实现了“Python 训练,Java 上线”的闭环。这对于已有大量 Java 基础设施的公司来说,是成本最低的 AI 落地方式。