构建基于Pulsar与Jib的Keras模型实时推理服务的生产实践


我们团队的技术栈以JVM为核心,稳定性和可维护性是首要考量。然而,算法团队交付的产出物通常是Python环境下的Keras/TensorFlow模型。过去,我们通过Python Flask或FastAPI将其包装成HTTP服务,再在Kubernetes中部署。这个模式暴露的问题越来越严重:GIL导致的并发性能瓶颈、繁琐的Conda环境管理、以及庞大且构建缓慢的Docker镜像。每一次模型的小幅更新,CI流水线都需要花费十几分钟重新构建包含整个Python环境的镜像,这在需要快速迭代的业务场景中是无法接受的。

痛点明确了:我们需要一个能将AI模型无缝集成到现有Java微服务生态中的方案。这个方案必须是高性能、高可用的,并且其构建和部署流程要与我们现有的云原生CI/CD体系完全兼容。

初步构想与技术选型

我们的目标是构建一个纯JVM的推理服务。它将作为我们事件驱动架构中的一个标准处理单元,从上游Pulsar Topic消费数据,执行模型推理,然后将结果推送到下游的Pulsar Topic。

1. 模型运行环境:TensorFlow for Java

要在JVM上运行Keras模型,最直接的方式是利用TensorFlow官方提供的Java API。Keras模型可以被导出为标准的SavedModel格式,该格式包含了完整的计算图和权重,可以被多种语言的TensorFlow运行时加载。这让我们能够摆脱Python运行时,直接在Java应用中执行原生的高性能推理。

2. 消息与流处理:Apache Pulsar

这是我们内部的既定标准,无需再选。我们将利用Pulsar的以下特性:

  • 分区Topic (Partitioned Topics): 水平扩展我们的推理服务实例,每个实例消费一个或多个分区,实现高吞吐。
  • 订阅模式 (Subscription Types): 采用Failover订阅类型来确保消息处理的顺序性(如果业务需要)和高可用,或者Shared类型以最大化并行处理能力。
  • 消息确认与重试: 利用Pulsar的negativeAcknowledge机制实现可靠的错误处理和消息重试。

3. 容器化方案:Google Jib

这是本次实践中一个关键的效率提升点。传统的Dockerfile构建方式有两大弊病:

  • 依赖Docker守护进程: 在CI/CD环境中,运行Docker-in-Docker (DinD) 会引入安全风险和性能开销。
  • 镜像分层不佳: COPY target/*.jar app.jar这样的指令,任何代码的改动都会导致整个应用的JAR包层失效,进而需要重新上传一个巨大的层。

Jib通过一个Maven或Gradle插件直接分析项目结构,将依赖(Dependencies)、资源(Resources)和类文件(Classes)分离到不同的镜像层中。这意味着,如果我们只修改了业务代码,只有最小的“Classes”层会被重新构建和推送,极大地加快了CI/CD流程。最重要的是,它完全不需要Docker守护进程。

步骤化实现:从模型加载到无守护进程构建

1. 准备工作:模型与项目结构

首先,算法团队需要提供SavedModel格式的模型。一个典型的Keras模型导出代码如下:

import tensorflow as tf
from tensorflow import keras

# 假设 model 是一个已经训练好的 Keras 模型
# model = ...

# 将其保存为 SavedModel 格式
tf.saved_model.save(model, "./models/fraud_detector/1")

这会创建一个包含saved_model.pbvariables子目录的文件夹。我们将这个fraud_detector文件夹打包到Java服务的资源目录中。

我们的Java项目是一个标准的Maven项目,核心依赖如下:

<!-- pom.xml -->
<properties>
    <maven.compiler.source>11</maven.compiler.source>
    <maven.compiler.target>11</maven.compiler.target>
    <pulsar.version>2.10.2</pulsar.version>
    <tensorflow.version>2.9.1</tensorflow.version>
    <jib-maven-plugin.version>3.3.1</jib-maven-plugin.version>
</properties>

<dependencies>
    <!-- Pulsar 客户端 -->
    <dependency>
        <groupId>org.apache.pulsar</groupId>
        <artifactId>pulsar-client</artifactId>
        <version>${pulsar.version}</version>
    </dependency>

    <!-- TensorFlow Java API -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-core-platform</artifactId>
        <version>${tensorflow.version}</version>
    </dependency>

    <!-- 日志 -->
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-simple</artifactId>
        <version>1.7.36</version>
    </dependency>
</dependencies>

2. 核心实现:模型加载与推理引擎

创建一个InferenceService来封装模型的加载和推理逻辑。在生产环境中,模型加载是一个昂贵的操作,应该在服务启动时完成,且只执行一次。

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;

import java.io.Closeable;
import java.nio.file.Paths;

public class InferenceService implements Closeable {

    private static final Logger log = LoggerFactory.getLogger(InferenceService.class);

    private final SavedModelBundle modelBundle;

    public InferenceService(String modelPath) {
        log.info("Attempting to load model from path: {}", modelPath);
        try {
            // SavedModelBundle.load() 是线程安全的,可以在多线程环境中共享
            this.modelBundle = SavedModelBundle.load(modelPath, "serve");
            log.info("Model loaded successfully. SignatureDef: {}",
                    modelBundle.signature().toString());
        } catch (Exception e) {
            log.error("Failed to load TensorFlow model from path: " + modelPath, e);
            // 在真实项目中,这里应该抛出受检异常或导致应用启动失败
            throw new RuntimeException("Model loading failed", e);
        }
    }

    /**
     * 执行推理。输入是一个二维浮点数组,代表批处理的特征。
     * @param features 输入特征, shape [batch_size, num_features]
     * @return 推理结果, shape [batch_size, num_outputs]
     */
    public float[][] predict(float[][] features) {
        // Session 是推理的入口,但它不是线程安全的。
        // 不过,可以从线程安全的 SavedModelBundle 中为每个请求创建一个新的 Session。
        // 更好的做法是在每个工作线程中持有一个 Session 实例。
        // 这里为了简化,我们在方法内创建和关闭。
        try (Session session = modelBundle.session()) {
            
            // 1. 将 Java 原生数组转换为 TensorFlow Tensor
            // TFloat32.tensorOf() 是创建Tensor的便捷方法
            try (Tensor<TFloat32> inputTensor = TFloat32.tensorOf(NdArrays.ofFloats(Shape.of(features.length, features[0].length), features))) {

                // 2. 运行推理
                // "serving_default" 是 SavedModel 的默认签名
                // "input_1" 是模型输入层的名称,这需要和模型定义者确认
                // "dense_1" 是模型输出层的名称
                Tensor<?> resultTensor = session.runner()
                        .feed("serving_default_input_1", inputTensor) // 输入节点名要与模型匹配
                        .fetch("StatefulPartitionedCall:0")     // 输出节点名也要与模型匹配
                        .run()
                        .get(0);

                // 3. 将结果 Tensor 转换回 Java 原生数组
                // 注意结果张量的 shape 和数据类型
                if (!(resultTensor instanceof TFloat32)) {
                    throw new IllegalStateException("Expected TFloat32 tensor as result, but got " + resultTensor.getClass().getName());
                }
                
                TFloat32 resultTFloat32 = (TFloat32) resultTensor;
                long[] shape = resultTFloat32.shape().asArray();
                if (shape.length != 2) {
                    throw new IllegalStateException("Expected a 2D result tensor, but shape is " + java.util.Arrays.toString(shape));
                }

                // 创建一个 Java 数组来接收结果
                float[][] prediction = new float[(int)shape[0]][(int)shape[1]];
                resultTFloat32.copyTo(prediction);
                
                return prediction;
            }
        } catch (Exception e) {
            log.error("Inference execution failed", e);
            // 根据业务需求返回null, 空数组, 或抛出异常
            return new float[0][0];
        }
    }

    @Override
    public void close() {
        if (modelBundle != null) {
            modelBundle.close();
            log.info("TensorFlow model bundle closed.");
        }
    }
}

这里的坑在于:

  1. 线程安全: SavedModelBundle是线程安全的,但它内部的Session不是。最佳实践是在每个处理线程中创建一个Session实例并复用它,而不是像上面示例中那样在每次predict调用时都创建。
  2. 节点名称: feed()fetch()中的节点名称(如serving_default_input_1, StatefulPartitionedCall:0)必须与SavedModel中的签名定义完全匹配。这些信息需要通过saved_model_cli工具或与模型开发者沟通来获取。
  3. 资源管理: TensorSavedModelBundle都占用了本地内存(甚至可能是GPU显存),必须用try-with-resources或手动close()来确保资源被释放,否则会导致严重的内存泄漏。

3. 集成Pulsar消费者与异步处理

接下来,我们创建一个Pulsar消费者来接收数据,调用InferenceService,然后将结果发送出去。为了提高性能并处理背压,我们采用完全异步的模式。

import org.apache.pulsar.client.api.*;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class PulsarInferenceWorker implements Runnable, AutoCloseable {

    private static final Logger log = LoggerFactory.getLogger(PulsarInferenceWorker.class);

    private final PulsarClient client;
    private final Consumer<byte[]> consumer;
    private final Producer<byte[]> producer;
    private final InferenceService inferenceService;
    // 使用专用的线程池来执行计算密集型的推理任务,避免阻塞Pulsar的IO线程
    private final ExecutorService inferenceExecutor;

    public PulsarInferenceWorker(String serviceUrl, String inputTopic, String outputTopic, String modelPath) throws PulsarClientException {
        this.client = PulsarClient.builder().serviceUrl(serviceUrl).build();
        this.consumer = client.newConsumer()
                .topic(inputTopic)
                .subscriptionName("keras-inference-subscription")
                .subscriptionType(SubscriptionType.Shared) // 使用 Shared 模式最大化并行
                .receiverQueueSize(100) // 控制内存中的消息缓冲
                .subscribe();
        this.producer = client.newProducer()
                .topic(outputTopic)
                .create();
        this.inferenceService = new InferenceService(modelPath);
        this.inferenceExecutor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
    }

    @Override
    public void run() {
        log.info("Pulsar inference worker started. Listening on topic: {}", consumer.getTopic());
        while (true) {
            try {
                // 1. 异步接收消息
                CompletableFuture<Message<byte[]>> future = consumer.receiveAsync();
                
                future.thenAcceptAsync(message -> {
                    try {
                        // 2. 解析消息并执行推理
                        // 生产环境中,这里应该是反序列化,例如JSON或Avro
                        float[][] features = parseMessage(message.getData());
                        float[][] result = inferenceService.predict(features);
                        byte[] outputPayload = serializeResult(result);

                        // 3. 异步发送结果
                        producer.sendAsync(outputPayload).thenAccept(messageId -> {
                            log.debug("Successfully produced result for messageId: {}", message.getMessageId());
                            // 4. 只有当结果成功发送后,才确认原始消息
                            consumer.acknowledgeAsync(message);
                        }).exceptionally(ex -> {
                            log.error("Failed to produce result for messageId: {}", message.getMessageId(), ex);
                            // 生产失败,不确认消息,Pulsar会自动重传
                            consumer.negativeAcknowledge(message);
                            return null;
                        });

                    } catch (Exception e) {
                        log.error("Error processing messageId: {}", message.getMessageId(), e);
                        // 处理失败,不确认消息,让Pulsar在延迟后重传
                        consumer.negativeAcknowledge(message);
                    }
                }, inferenceExecutor).exceptionally(ex -> {
                    log.error("Failed to receive message from Pulsar", ex);
                    // 发生严重错误,可能需要重启服务
                    return null;
                });
                
                // 等待异步操作完成,这里只是为了演示循环,实际应用中可以有更复杂的逻辑
                future.join();

            } catch (Exception e) {
                log.error("Pulsar consumer loop interrupted", e);
                break;
            }
        }
    }
    
    // 单元测试时需要mock这些方法
    protected float[][] parseMessage(byte[] data) {
        // 伪代码: 将字节数组解析为二维浮点数组
        // 在真实项目中,这里会使用Jackson, Gson, or Avro
        return new float[][]{{1.0f, 2.0f, 3.0f}};
    }

    protected byte[] serializeResult(float[][] result) {
        // 伪代码: 将结果序列化为字节数组
        return java.util.Arrays.deepToString(result).getBytes(StandardCharsets.UTF_8);
    }


    @Override
    public void close() throws Exception {
        log.info("Closing Pulsar worker...");
        if (inferenceExecutor != null) {
            inferenceExecutor.shutdown();
            inferenceExecutor.awaitTermination(30, TimeUnit.SECONDS);
        }
        if (producer != null) producer.close();
        if (consumer != null) consumer.close();
        if (client != null) client.close();
        if (inferenceService != null) inferenceService.close();
        log.info("Pulsar worker closed.");
    }
}

这个架构展示了事件驱动模型推理的核心流程。

sequenceDiagram
    participant PulsarTopicIn as "Input Topic"
    participant Worker as "PulsarInferenceWorker"
    participant InferenceExecutor as "Inference Thread Pool"
    participant TFService as "InferenceService"
    participant PulsarTopicOut as "Output Topic"

    PulsarTopicIn->>Worker: consumer.receiveAsync()
    Worker->>InferenceExecutor: thenAcceptAsync(message, executor)
    Note right of Worker: Pulsar I/O 线程被释放,不被阻塞
    InferenceExecutor->>InferenceExecutor: parseMessage(message)
    InferenceExecutor->>TFService: predict(features)
    TFService-->>InferenceExecutor: return result
    InferenceExecutor->>InferenceExecutor: serializeResult(result)
    InferenceExecutor->>PulsarTopicOut: producer.sendAsync(payload)
    PulsarTopicOut-->>InferenceExecutor: CompletableFuture
    InferenceExecutor->>PulsarTopicIn: consumer.acknowledgeAsync(message)
    Note right of InferenceExecutor: 确认(ack)发生在结果发送成功后

4. 使用Jib实现无守护进程的容器化

最后一步是配置Jib Maven插件。在pom.xml<build><plugins>部分加入以下配置:

<!-- pom.xml -->
<plugin>
    <groupId>com.google.cloud.tools</groupId>
    <artifactId>jib-maven-plugin</artifactId>
    <version>${jib-maven-plugin.version}</version>
    <configuration>
        <from>
            <!-- 使用一个包含glibc的标准基础镜像 -->
            <image>eclipse-temurin:11-jre</image>
        </from>
        <to>
            <!-- 目标镜像仓库地址和名称 -->
            <image>my-registry.example.com/ml-inference/keras-pulsar-service</image>
            <tags>
                <tag>${project.version}</tag>
                <tag>latest</tag>
            </tags>
            <!-- 配置凭据,通常在CI环境中通过环境变量或配置文件提供 -->
            <auth>
                <username>${env.REGISTRY_USER}</username>
                <password>${env.REGISTRY_PASSWORD}</password>
            </auth>
        </to>
        <container>
            <!-- 设置JVM参数,对于内存敏感的模型服务至关重要 -->
            <jvmFlags>
                <jvmFlag>-Xms512m</jvmFlag>
                <jvmFlag>-Xmx2048m</jvmFlag>
                <jvmFlag>-XX:+UseG1GC</jvmFlag>
            </jvmFlags>
            <mainClass>com.yourcompany.MainApplication</mainClass>
            <ports>
                <!-- 如果有健康检查端点,可以暴露端口 -->
                <!-- <port>8080</port> -->
            </ports>
            <creationTime>USE_CURRENT_TIMESTAMP</creationTime>
        </container>
    </configuration>
    <executions>
        <execution>
            <phase>package</phase>
            <goals>
                <goal>build</goal> <!-- 'build' 会构建并推送到远程仓库 -->
            </goals>
        </execution>
    </executions>
</plugin>

现在,在CI流水线中,我们不再需要docker logindocker build。一个简单的命令就完成了所有工作:

mvn compile package -DREGISTRY_USER=... -DREGISTRY_PASSWORD=...

jib-maven-plugin会在package阶段自动执行。它会:

  1. 拉取基础镜像eclipse-temurin:11-jre
  2. 将项目依赖(JARs)作为一个层。
  3. 将项目资源(包括src/main/resources下的模型文件)作为另一个层。
  4. 将编译后的*.class文件作为最顶层。
  5. 将这些层组合成一个符合OCI标准的镜像,并直接推送到my-registry.example.com

当我们只修改了PulsarInferenceWorker.java中的业务逻辑时,只有包含.class文件的那个几KB大小的层会被重新构建和推送,CI/CD的效率得到了质的提升。

遗留问题与未来迭代路径

这个方案解决了我们最初的痛点,但它并非终点。在生产环境中,我们很快会遇到新的挑战。

当前架构最大的局限在于模型更新。每次模型迭代都需要重新构建和部署整个服务。一个更优雅的方案是实现模型的热加载。这可以设计成一个独立的控制流:运维或MLOps平台向一个专用的Pulsar控制Topic发送一条消息,其中包含新模型的路径(例如S3 URI)。服务实例消费到这条消息后,会安全地下载新模型,并在原子操作中替换掉内存中的旧模型实例。这需要仔细处理并发访问和资源释放,确保在切换过程中服务不中断。

其次,当前的实现是纯CPU推理。对于需要更高吞吐量的场景,支持GPU是必须的。这需要将Jib的基础镜像换成带有CUDA库的NVIDIA官方镜像(如nvidia/cuda:11.4.2-base-ubuntu20.04),并在Kubernetes部署文件中通过nodeSelectorresources.limits来请求GPU资源。这为基础设施和部署配置带来了额外的复杂度。

最后,为了极致的性能,可以引入动态批处理(Dynamic Batching)。服务不再是来一条消息处理一条,而是在一个极短的时间窗口内(如10毫秒)聚合多条消息,将它们的特征数据合并成一个更大的Tensor批次,一次性送入模型进行推理。这能显著提高硬件利用率,尤其是GPU。然而,它以牺牲少量延迟为代价,并且需要更复杂的逻辑来处理批次内的部分失败和结果分发。


  目录