在Java中运行AI推理:企业架构师的ONNX实战指南

大多数企业级应用仍然跑在Java上,但机.

大多数企业级应用仍然跑在Java上,但机器学习生态却几乎被Python垄断。模型训练完成后,想在生产环境里跑起来,往往得套一层REST服务或者微前端架构——这带来的延迟、复杂度和维护成本,任哪个架构师看了都会头疼。

有没有办法让AI模型直接在JVM里跑?答案是肯定的。本文要聊的ONNX(Open Neural Network Exchange)就是这个问题的解法。它让你能把PyTorch或Hugging Face上训练的模型,直接在Java环境里做推理,不需要Python进程,不需要额外的容器,也不用改写核心业务代码。

为什么企业架构师应该关注ONNX

先说清楚几个现实问题。Python模型要上生产,通常得走这条路:训练→导出→部署成gRPC/REST服务→Java应用调用。这中间多了个网络调用,多了套Python运行时,多了套监控体系,出了问题排查起来也麻烦。

ONNX的价值在于:它提供了一种标准化的模型格式,主流框架(PyTorch、TensorFlow、Hugging Face)都能导出,导出后的模型可以直接在Java里用ONNX Runtime加载运行。这意味着推理逻辑和业务代码在同一个JVM里,共享同一套线程池、监控、日志体系,排查问题直接在同一个堆栈里搞定。

对架构师来说,ONNX解锁了几个关键能力:

  • 语言一致性:推理逻辑跑在JVM内部,不是独立进程
  • 部署简化:不需要管理Python运行时,也不用维护REST代理层
  • 基础设施复用:直接用现有的Java监控、追踪、安全控制
  • GPU/CPU灵活切换:同一套代码,开发机用CPU,生产环境切GPU

系统架构设计

把AI推理集成到企业Java系统,不是简单扔个模型文件就完事了。好的设计应该把整个推理流程拆成职责清晰的独立组件,每个组件可测试、可替换、可独立部署。

整体架构大概是这个样子:

  1. 输入适配层:接收REST请求、Kafka消息、文件等各种来源的原始数据
  2. 分词器(Tokenizer):把文本转成模型能理解的数值序列,这一步依赖Hugging Face兼容的tokenizer.json文件
  3. ONNX推理引擎:调用ONNX Runtime执行模型推理,自动选择CPU或GPU后端
  4. 后处理模块:把模型的原始输出(logits)转成业务可用的实体(标签、分类、抽取的实体等)
  5. 输出路由:把结果写回数据库、推给下游系统或返回HTTP响应

这种分层设计的好处是:分词器可以换(不同任务可能需要不同的分词器),模型可以换(升级或切换任务类型),后端可以换(CPU切GPU不需要改业务代码),每个环节都能独立测试和监控。

依赖配置与项目初始化

先说Maven依赖,这一步没什么花样,直接把ONNX Runtime和Hugging Face的分词器库加进来就行:

<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.18.0</version>
</dependency>

<dependency>
    <groupId>com.huggingface</groupId>
    <artifactId>tokenizers</artifactId>
    <version>0.16.0</version>
</dependency>

如果要在生产环境用GPU,需要把onnxruntime改成onnxruntime_gpu,版本号保持一致。ONNX Runtime的自动发现机制会根据运行环境选择最优的执行提供程序(Execution Provider),有GPU就用CUPN,没有就回退到CPU。

分词器实现:Hugging Face兼容

分词器是整个流程里最容易出问题的环节。很多时候模型推理不准,不是模型本身的问题,而是分词器和模型训练时用的不一致。架构上一定要把分词器当成版本化管理的一等公民,和模型文件绑定在一起。

下面是个实用的分词器封装,支持从classpath或文件系统加载Hugging Face格式的tokenizer.json:

public class TokenizerService {
    
    private final Tokenizer tokenizer;
    
    public TokenizerService(String tokenizerPath) throws IOException {
        try (InputStream is = loadResource(tokenizerPath)) {
            this.tokenizer = Tokenizer.fromFile(is);
        }
    }
    
    private static InputStream loadResource(String path) throws IOException {
        // 优先从classpath加载,支持JAR包内打包
        InputStream is = TokenizerService.class.getClassLoader().getResourceAsStream(path);
        if (is == null) {
            // 回退到文件系统路径
            return new FileInputStream(path);
        }
        return is;
    }
    
    /**
     * 将文本编码为模型输入ID
     */
    public Encoding encode(String text) {
        return tokenizer.encode(text);
    }
    
    /**
     * 批量编码,支持长文本自动分块
     */
    public List<Encoding> encodeBatch(List<String> texts) {
        return tokenizer.encodeBatch(texts);
    }
}

这个封装有几个考量:优先从classpath加载意味着tokenizer.json可以直接打包进JAR,部署时少一个文件需要管理;返回的Encoding对象包含input IDs、attention mask等模型需要的全部字段,直接传给推理引擎就行。

ONNX推理引擎:核心实现

推理引擎是整个系统的发动机。这里要注意几个点:Session要复用不要重建(初始化成本很高)、多线程访问要线程安全、GPU内存要记得释放。下面是个经过生产验证的实现:

public class OnnxInferenceEngine implements AutoCloseable {
    
    private final OrtEnvironment environment;
    private final OrtSession session;
    private final String inputName;
    private final String outputName;
    
    public OnnxInferenceEngine(String modelPath) throws OrtException {
        this.environment = OrtEnvironment.getEnvironment();
        
        // SessionOptions控制执行行为
        SessionOptions sessionOptions = new SessionOptions();
        
        // 开启优化(算子融合、内存优化等)
        sessionOptions.setOptimizationLevel(OptimizationLevel.ALL_OPT);
        
        // 根据运行环境自动选择CPU或GPU
        // 有CUDA就优先用,没有就CPU
        try {
            sessionOptions.addCUDA();
        } catch (OrtException e) {
            // 记录日志,回退到CPU
            System.err.println("GPU不可用,使用CPU执行: " + e.getMessage());
        }
        
        // 设置并行度,CPU模式下可以控制线程数
        sessionOptions.setIntraOpNumNumThreads(Runtime.getRuntime().availableProcessors());
        
        // 加载模型
        this.session = environment.createSession(modelPath, sessionOptions);
        
        // 获取输入输出名称(简化版假设单输入单输出)
        this.inputName = session.getInputNames().iterator().next();
        this.outputName = session.getOutputNames().iterator().next();
    }
    
    /**
     * 执行推理,返回原始logits
     */
    public float[][] infer(long[][] inputIds, long[][] attentionMask) throws OrtException {
        Map<String, OnnxTensor> inputs = new HashMap<>();
        
        inputs.put(inputName, OnnxTensor.createTensor(environment, inputIds));
        inputs.put("attention_mask", OnnxTensor.createTensor(environment, attentionMask));
        
        try {
            OrtSession.Result result = session.run(inputs);
            // 提取logits数组
            return (float[][]) result.get(0).getValue();
        } finally {
            // 确保tensor被释放,防止内存泄漏
            inputs.values().forEach(OnnxTensor::close);
        }
    }
    
    @Override
    public void close() {
        session.close();
        environment.close();
    }
}

这段代码里有几个关键细节值得说说。首先是GPU自动发现——调用addCUDA()如果失败了(没有GPU或驱动不对),代码会静默回退到CPU,生产环境不会出现启动失败的问题。然后是Tensor生命周期管理——每个OnnxTensor都必须手动close,否则GPU内存会慢慢耗尽,JVM不会自动回收这些native内存。最后是优化级别——ALL_OPT会开启算子融合等优化,这对transformer模型的性能提升很明显。

后处理:把logits转成业务结果

模型输出的是logits,也就是未经softmax的原始分数。后处理的任务是把这些分数转成业务可用的标签或实体。下面以文本分类为例演示后处理逻辑:

public class ClassificationPostProcessor {
    
    private final List<String> labels;
    
    public ClassificationPostProcessor(List<String> labels) {
        this.labels = labels;
    }
    
    /**
     * 从logits中提取最高概率的分类
     */
    public ClassificationResult process(float[][] logits) {
        float[] scores = softmax(logits[0]);
        int predictedClass = argmax(scores);
        
        return new ClassificationResult(
            labels.get(predictedClass),
            scores[predictedClass],
            scores
        );
    }
    
    private float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        float sum = 0;
        
        for (float logit : logits) {
            max = Math.max(max, logit);
        }
        
        float[] exp = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            exp[i] = (float) Math.exp(logits[i] - max);
            sum += exp[i];
        }
        
        for (int i = 0; i < exp.length; i++) {
            exp[i] /= sum;
        }
        
        return exp;
    }
    
    private int argmax(float[] scores) {
        int best = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > bestScore) {
                bestScore = scores[i];
                best = i;
            }
        }
        
        return best;
    }
}

softmax的计算需要注意数值稳定性,先减最大值再exp可以避免溢出。后处理封装成独立类的好处是:单元测试可以直接造logits数据验证后处理逻辑,不需要真实模型参与,测试速度和覆盖率都更有保障。

整合:组装完整的推理服务

把上面的组件串起来,就是一个完整的推理服务。这个服务应该设计成Spring Bean或类似的依赖注入形式,方便集成到现有系统:

@Service
public class TextClassificationService implements InitializingBean, DisposableBean {
    
    private TokenizerService tokenizer;
    private OnnxInferenceEngine engine;
    private ClassificationPostProcessor postProcessor;
    
    @Value("${model.base-path:/models}")
    private String modelBasePath;
    
    @Override
    public void afterPropertiesSet() throws Exception {
        // 模型和分词器应该从配置文件指定,支持不同环境不同版本
        String tokenizerFile = modelBasePath + "/tokenizer.json";
        String modelFile = modelBasePath + "/model.onnx";
        
        this.tokenizer = new TokenizerService(tokenizerFile);
        this.engine = new OnnxInferenceEngine(modelFile);
        this.postProcessor = new ClassificationPostProcessor(List.of(
            "正面", "负面", "中性"
        ));
    }
    
    public ClassificationResult classify(String text) throws Exception {
        // 1. 分词
        Encoding encoding = tokenizer.encode(text);
        
        // 2. 转成模型输入格式
        long[][] inputIds = new long[][] {encoding.getIds()};
        long[][] attentionMask = new long[][] {encoding.getAttentionMask()};
        
        // 3. 推理
        float[][] logits = engine.infer(inputIds, attentionMask);
        
        // 4. 后处理
        return postProcessor.process(logits);
    }
    
    @Override
    public void destroy() {
        // 优雅关闭,释放GPU资源
        if (engine != null) {
            engine.close();
        }
    }
}

这个服务实现了InitializingBean和DisposableBean,Spring启动时自动加载模型,关闭时自动释放资源。model.base-path从配置文件读取,这样DEV、TEST、PROD环境可以用不同的模型版本,回滚也只需要改配置。

生产环境注意事项

把模型跑起来只是第一步,生产环境要考虑的远比这多。

模型与分词器的版本对齐

这是生产环境最常见的坑。模型升级了但分词器没同步,或者反过来,都会导致推理结果异常。解决方案是把模型和分词器打包在一起,作为同一个部署单元管理。目录结构建议这样:

/models/v1.0.0/
├── model.onnx
├── tokenizer.json
└── config.json  # 记录版本号、创建时间等元数据

每次模型发布都带完整的文件集合,部署时整个目录一起替换,避免部分更新导致的不一致。

内存与GPU资源管理

ONNX Runtime创建的Tensor使用的是native内存,不受JVM堆内存管理。生产环境有几个建议:

  • 设置-Xmx时不要把内存全分给JVM,留一部分给ONNX Runtime的native堆
  • GPU模式下要监控显存使用,可以在应用里暴露/metrics端点
  • 长时间运行的服务定期检查是否有Tensor泄漏(对象没close)

并发与线程安全

OrtSession本身是线程安全的,可以多个请求并发调用。但要注意:

  • 每次推理创建的OnnxTensor必须及时释放(在finally块里close)
  • 分词器可以多线程共享,Hugging Face的tokenizers库是线程安全的
  • 如果并发量很高,考虑用对象池复用OnnxTensor,减少创建销毁开销

可观测性

AI推理应该和普通业务逻辑一样可观测。建议埋的点:

  • 推理耗时(分词+推理+后处理的分解耗时)
  • 模型版本号(从config.json读取)
  • 输入输出采样(用于debug和模型监控)
  • GPU利用率(如果用GPU的话)

总结与下一步

ONNX让在Java里跑AI推理变成了现实。不用Python进程,不用微服务拆分,模型直接嵌进JVM里,和业务代码共享同一套基础设施。对企业架构师来说,这意味着AI能力可以平滑接入现有的Java体系,不需要为了AI重建整个技术栈。

落地建议:先从CPU环境跑通全流程,验证分词器和模型版本对齐没问题,再逐步引入GPU。模型和分词器一定要版本化管理,作为完整的部署单元一起发布。推理组件设计成可替换的接口层,方便后续切换模型或调整后处理逻辑。

如果你的团队已经在用Spring Boot,这套模式可以直接复用现有的依赖注入和配置管理。如果还没用Spring,核心的OnnxInferenceEngine和TokenizerService也是Plain Java对象,随时可以集成进去。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注