大多数企业级应用仍然跑在Java上,但机器学习生态却几乎被Python垄断。模型训练完成后,想在生产环境里跑起来,往往得套一层REST服务或者微前端架构——这带来的延迟、复杂度和维护成本,任哪个架构师看了都会头疼。
有没有办法让AI模型直接在JVM里跑?答案是肯定的。本文要聊的ONNX(Open Neural Network Exchange)就是这个问题的解法。它让你能把PyTorch或Hugging Face上训练的模型,直接在Java环境里做推理,不需要Python进程,不需要额外的容器,也不用改写核心业务代码。
为什么企业架构师应该关注ONNX
先说清楚几个现实问题。Python模型要上生产,通常得走这条路:训练→导出→部署成gRPC/REST服务→Java应用调用。这中间多了个网络调用,多了套Python运行时,多了套监控体系,出了问题排查起来也麻烦。
ONNX的价值在于:它提供了一种标准化的模型格式,主流框架(PyTorch、TensorFlow、Hugging Face)都能导出,导出后的模型可以直接在Java里用ONNX Runtime加载运行。这意味着推理逻辑和业务代码在同一个JVM里,共享同一套线程池、监控、日志体系,排查问题直接在同一个堆栈里搞定。
对架构师来说,ONNX解锁了几个关键能力:
- 语言一致性:推理逻辑跑在JVM内部,不是独立进程
- 部署简化:不需要管理Python运行时,也不用维护REST代理层
- 基础设施复用:直接用现有的Java监控、追踪、安全控制
- GPU/CPU灵活切换:同一套代码,开发机用CPU,生产环境切GPU
系统架构设计
把AI推理集成到企业Java系统,不是简单扔个模型文件就完事了。好的设计应该把整个推理流程拆成职责清晰的独立组件,每个组件可测试、可替换、可独立部署。
整体架构大概是这个样子:
- 输入适配层:接收REST请求、Kafka消息、文件等各种来源的原始数据
- 分词器(Tokenizer):把文本转成模型能理解的数值序列,这一步依赖Hugging Face兼容的tokenizer.json文件
- ONNX推理引擎:调用ONNX Runtime执行模型推理,自动选择CPU或GPU后端
- 后处理模块:把模型的原始输出(logits)转成业务可用的实体(标签、分类、抽取的实体等)
- 输出路由:把结果写回数据库、推给下游系统或返回HTTP响应
这种分层设计的好处是:分词器可以换(不同任务可能需要不同的分词器),模型可以换(升级或切换任务类型),后端可以换(CPU切GPU不需要改业务代码),每个环节都能独立测试和监控。
依赖配置与项目初始化
先说Maven依赖,这一步没什么花样,直接把ONNX Runtime和Hugging Face的分词器库加进来就行:
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.18.0</version>
</dependency>
<dependency>
<groupId>com.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>0.16.0</version>
</dependency>
如果要在生产环境用GPU,需要把onnxruntime改成onnxruntime_gpu,版本号保持一致。ONNX Runtime的自动发现机制会根据运行环境选择最优的执行提供程序(Execution Provider),有GPU就用CUPN,没有就回退到CPU。
分词器实现:Hugging Face兼容
分词器是整个流程里最容易出问题的环节。很多时候模型推理不准,不是模型本身的问题,而是分词器和模型训练时用的不一致。架构上一定要把分词器当成版本化管理的一等公民,和模型文件绑定在一起。
下面是个实用的分词器封装,支持从classpath或文件系统加载Hugging Face格式的tokenizer.json:
public class TokenizerService {
private final Tokenizer tokenizer;
public TokenizerService(String tokenizerPath) throws IOException {
try (InputStream is = loadResource(tokenizerPath)) {
this.tokenizer = Tokenizer.fromFile(is);
}
}
private static InputStream loadResource(String path) throws IOException {
// 优先从classpath加载,支持JAR包内打包
InputStream is = TokenizerService.class.getClassLoader().getResourceAsStream(path);
if (is == null) {
// 回退到文件系统路径
return new FileInputStream(path);
}
return is;
}
/**
* 将文本编码为模型输入ID
*/
public Encoding encode(String text) {
return tokenizer.encode(text);
}
/**
* 批量编码,支持长文本自动分块
*/
public List<Encoding> encodeBatch(List<String> texts) {
return tokenizer.encodeBatch(texts);
}
}
这个封装有几个考量:优先从classpath加载意味着tokenizer.json可以直接打包进JAR,部署时少一个文件需要管理;返回的Encoding对象包含input IDs、attention mask等模型需要的全部字段,直接传给推理引擎就行。
ONNX推理引擎:核心实现
推理引擎是整个系统的发动机。这里要注意几个点:Session要复用不要重建(初始化成本很高)、多线程访问要线程安全、GPU内存要记得释放。下面是个经过生产验证的实现:
public class OnnxInferenceEngine implements AutoCloseable {
private final OrtEnvironment environment;
private final OrtSession session;
private final String inputName;
private final String outputName;
public OnnxInferenceEngine(String modelPath) throws OrtException {
this.environment = OrtEnvironment.getEnvironment();
// SessionOptions控制执行行为
SessionOptions sessionOptions = new SessionOptions();
// 开启优化(算子融合、内存优化等)
sessionOptions.setOptimizationLevel(OptimizationLevel.ALL_OPT);
// 根据运行环境自动选择CPU或GPU
// 有CUDA就优先用,没有就CPU
try {
sessionOptions.addCUDA();
} catch (OrtException e) {
// 记录日志,回退到CPU
System.err.println("GPU不可用,使用CPU执行: " + e.getMessage());
}
// 设置并行度,CPU模式下可以控制线程数
sessionOptions.setIntraOpNumNumThreads(Runtime.getRuntime().availableProcessors());
// 加载模型
this.session = environment.createSession(modelPath, sessionOptions);
// 获取输入输出名称(简化版假设单输入单输出)
this.inputName = session.getInputNames().iterator().next();
this.outputName = session.getOutputNames().iterator().next();
}
/**
* 执行推理,返回原始logits
*/
public float[][] infer(long[][] inputIds, long[][] attentionMask) throws OrtException {
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put(inputName, OnnxTensor.createTensor(environment, inputIds));
inputs.put("attention_mask", OnnxTensor.createTensor(environment, attentionMask));
try {
OrtSession.Result result = session.run(inputs);
// 提取logits数组
return (float[][]) result.get(0).getValue();
} finally {
// 确保tensor被释放,防止内存泄漏
inputs.values().forEach(OnnxTensor::close);
}
}
@Override
public void close() {
session.close();
environment.close();
}
}
这段代码里有几个关键细节值得说说。首先是GPU自动发现——调用addCUDA()如果失败了(没有GPU或驱动不对),代码会静默回退到CPU,生产环境不会出现启动失败的问题。然后是Tensor生命周期管理——每个OnnxTensor都必须手动close,否则GPU内存会慢慢耗尽,JVM不会自动回收这些native内存。最后是优化级别——ALL_OPT会开启算子融合等优化,这对transformer模型的性能提升很明显。
后处理:把logits转成业务结果
模型输出的是logits,也就是未经softmax的原始分数。后处理的任务是把这些分数转成业务可用的标签或实体。下面以文本分类为例演示后处理逻辑:
public class ClassificationPostProcessor {
private final List<String> labels;
public ClassificationPostProcessor(List<String> labels) {
this.labels = labels;
}
/**
* 从logits中提取最高概率的分类
*/
public ClassificationResult process(float[][] logits) {
float[] scores = softmax(logits[0]);
int predictedClass = argmax(scores);
return new ClassificationResult(
labels.get(predictedClass),
scores[predictedClass],
scores
);
}
private float[] softmax(float[] logits) {
float max = Float.NEGATIVE_INFINITY;
float sum = 0;
for (float logit : logits) {
max = Math.max(max, logit);
}
float[] exp = new float[logits.length];
for (int i = 0; i < logits.length; i++) {
exp[i] = (float) Math.exp(logits[i] - max);
sum += exp[i];
}
for (int i = 0; i < exp.length; i++) {
exp[i] /= sum;
}
return exp;
}
private int argmax(float[] scores) {
int best = 0;
float bestScore = Float.NEGATIVE_INFINITY;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > bestScore) {
bestScore = scores[i];
best = i;
}
}
return best;
}
}
softmax的计算需要注意数值稳定性,先减最大值再exp可以避免溢出。后处理封装成独立类的好处是:单元测试可以直接造logits数据验证后处理逻辑,不需要真实模型参与,测试速度和覆盖率都更有保障。
整合:组装完整的推理服务
把上面的组件串起来,就是一个完整的推理服务。这个服务应该设计成Spring Bean或类似的依赖注入形式,方便集成到现有系统:
@Service
public class TextClassificationService implements InitializingBean, DisposableBean {
private TokenizerService tokenizer;
private OnnxInferenceEngine engine;
private ClassificationPostProcessor postProcessor;
@Value("${model.base-path:/models}")
private String modelBasePath;
@Override
public void afterPropertiesSet() throws Exception {
// 模型和分词器应该从配置文件指定,支持不同环境不同版本
String tokenizerFile = modelBasePath + "/tokenizer.json";
String modelFile = modelBasePath + "/model.onnx";
this.tokenizer = new TokenizerService(tokenizerFile);
this.engine = new OnnxInferenceEngine(modelFile);
this.postProcessor = new ClassificationPostProcessor(List.of(
"正面", "负面", "中性"
));
}
public ClassificationResult classify(String text) throws Exception {
// 1. 分词
Encoding encoding = tokenizer.encode(text);
// 2. 转成模型输入格式
long[][] inputIds = new long[][] {encoding.getIds()};
long[][] attentionMask = new long[][] {encoding.getAttentionMask()};
// 3. 推理
float[][] logits = engine.infer(inputIds, attentionMask);
// 4. 后处理
return postProcessor.process(logits);
}
@Override
public void destroy() {
// 优雅关闭,释放GPU资源
if (engine != null) {
engine.close();
}
}
}
这个服务实现了InitializingBean和DisposableBean,Spring启动时自动加载模型,关闭时自动释放资源。model.base-path从配置文件读取,这样DEV、TEST、PROD环境可以用不同的模型版本,回滚也只需要改配置。
生产环境注意事项
把模型跑起来只是第一步,生产环境要考虑的远比这多。
模型与分词器的版本对齐
这是生产环境最常见的坑。模型升级了但分词器没同步,或者反过来,都会导致推理结果异常。解决方案是把模型和分词器打包在一起,作为同一个部署单元管理。目录结构建议这样:
/models/v1.0.0/
├── model.onnx
├── tokenizer.json
└── config.json # 记录版本号、创建时间等元数据
每次模型发布都带完整的文件集合,部署时整个目录一起替换,避免部分更新导致的不一致。
内存与GPU资源管理
ONNX Runtime创建的Tensor使用的是native内存,不受JVM堆内存管理。生产环境有几个建议:
- 设置-Xmx时不要把内存全分给JVM,留一部分给ONNX Runtime的native堆
- GPU模式下要监控显存使用,可以在应用里暴露/metrics端点
- 长时间运行的服务定期检查是否有Tensor泄漏(对象没close)
并发与线程安全
OrtSession本身是线程安全的,可以多个请求并发调用。但要注意:
- 每次推理创建的OnnxTensor必须及时释放(在finally块里close)
- 分词器可以多线程共享,Hugging Face的tokenizers库是线程安全的
- 如果并发量很高,考虑用对象池复用OnnxTensor,减少创建销毁开销
可观测性
AI推理应该和普通业务逻辑一样可观测。建议埋的点:
- 推理耗时(分词+推理+后处理的分解耗时)
- 模型版本号(从config.json读取)
- 输入输出采样(用于debug和模型监控)
- GPU利用率(如果用GPU的话)
总结与下一步
ONNX让在Java里跑AI推理变成了现实。不用Python进程,不用微服务拆分,模型直接嵌进JVM里,和业务代码共享同一套基础设施。对企业架构师来说,这意味着AI能力可以平滑接入现有的Java体系,不需要为了AI重建整个技术栈。
落地建议:先从CPU环境跑通全流程,验证分词器和模型版本对齐没问题,再逐步引入GPU。模型和分词器一定要版本化管理,作为完整的部署单元一起发布。推理组件设计成可替换的接口层,方便后续切换模型或调整后处理逻辑。
如果你的团队已经在用Spring Boot,这套模式可以直接复用现有的依赖注入和配置管理。如果还没用Spring,核心的OnnxInferenceEngine和TokenizerService也是Plain Java对象,随时可以集成进去。