我们团队的技术栈以JVM为核心,稳定性和可维护性是首要考量。然而,算法团队交付的产出物通常是Python环境下的Keras/TensorFlow模型。过去,我们通过Python Flask或FastAPI将其包装成HTTP服务,再在Kubernetes中部署。这个模式暴露的问题越来越严重:GIL导致的并发性能瓶颈、繁琐的Conda环境管理、以及庞大且构建缓慢的Docker镜像。每一次模型的小幅更新,CI流水线都需要花费十几分钟重新构建包含整个Python环境的镜像,这在需要快速迭代的业务场景中是无法接受的。
痛点明确了:我们需要一个能将AI模型无缝集成到现有Java微服务生态中的方案。这个方案必须是高性能、高可用的,并且其构建和部署流程要与我们现有的云原生CI/CD体系完全兼容。
初步构想与技术选型
我们的目标是构建一个纯JVM的推理服务。它将作为我们事件驱动架构中的一个标准处理单元,从上游Pulsar Topic消费数据,执行模型推理,然后将结果推送到下游的Pulsar Topic。
1. 模型运行环境:TensorFlow for Java
要在JVM上运行Keras模型,最直接的方式是利用TensorFlow官方提供的Java API。Keras模型可以被导出为标准的SavedModel格式,该格式包含了完整的计算图和权重,可以被多种语言的TensorFlow运行时加载。这让我们能够摆脱Python运行时,直接在Java应用中执行原生的高性能推理。
2. 消息与流处理:Apache Pulsar
这是我们内部的既定标准,无需再选。我们将利用Pulsar的以下特性:
- 分区Topic (Partitioned Topics): 水平扩展我们的推理服务实例,每个实例消费一个或多个分区,实现高吞吐。
- 订阅模式 (Subscription Types): 采用
Failover订阅类型来确保消息处理的顺序性(如果业务需要)和高可用,或者Shared类型以最大化并行处理能力。 - 消息确认与重试: 利用Pulsar的
negativeAcknowledge机制实现可靠的错误处理和消息重试。
3. 容器化方案:Google Jib
这是本次实践中一个关键的效率提升点。传统的Dockerfile构建方式有两大弊病:
- 依赖Docker守护进程: 在CI/CD环境中,运行Docker-in-Docker (DinD) 会引入安全风险和性能开销。
- 镜像分层不佳:
COPY target/*.jar app.jar这样的指令,任何代码的改动都会导致整个应用的JAR包层失效,进而需要重新上传一个巨大的层。
Jib通过一个Maven或Gradle插件直接分析项目结构,将依赖(Dependencies)、资源(Resources)和类文件(Classes)分离到不同的镜像层中。这意味着,如果我们只修改了业务代码,只有最小的“Classes”层会被重新构建和推送,极大地加快了CI/CD流程。最重要的是,它完全不需要Docker守护进程。
步骤化实现:从模型加载到无守护进程构建
1. 准备工作:模型与项目结构
首先,算法团队需要提供SavedModel格式的模型。一个典型的Keras模型导出代码如下:
import tensorflow as tf
from tensorflow import keras
# 假设 model 是一个已经训练好的 Keras 模型
# model = ...
# 将其保存为 SavedModel 格式
tf.saved_model.save(model, "./models/fraud_detector/1")
这会创建一个包含saved_model.pb和variables子目录的文件夹。我们将这个fraud_detector文件夹打包到Java服务的资源目录中。
我们的Java项目是一个标准的Maven项目,核心依赖如下:
<!-- pom.xml -->
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<pulsar.version>2.10.2</pulsar.version>
<tensorflow.version>2.9.1</tensorflow.version>
<jib-maven-plugin.version>3.3.1</jib-maven-plugin.version>
</properties>
<dependencies>
<!-- Pulsar 客户端 -->
<dependency>
<groupId>org.apache.pulsar</groupId>
<artifactId>pulsar-client</artifactId>
<version>${pulsar.version}</version>
</dependency>
<!-- TensorFlow Java API -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-core-platform</artifactId>
<version>${tensorflow.version}</version>
</dependency>
<!-- 日志 -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>1.7.36</version>
</dependency>
</dependencies>
2. 核心实现:模型加载与推理引擎
创建一个InferenceService来封装模型的加载和推理逻辑。在生产环境中,模型加载是一个昂贵的操作,应该在服务启动时完成,且只执行一次。
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;
import java.io.Closeable;
import java.nio.file.Paths;
public class InferenceService implements Closeable {
private static final Logger log = LoggerFactory.getLogger(InferenceService.class);
private final SavedModelBundle modelBundle;
public InferenceService(String modelPath) {
log.info("Attempting to load model from path: {}", modelPath);
try {
// SavedModelBundle.load() 是线程安全的,可以在多线程环境中共享
this.modelBundle = SavedModelBundle.load(modelPath, "serve");
log.info("Model loaded successfully. SignatureDef: {}",
modelBundle.signature().toString());
} catch (Exception e) {
log.error("Failed to load TensorFlow model from path: " + modelPath, e);
// 在真实项目中,这里应该抛出受检异常或导致应用启动失败
throw new RuntimeException("Model loading failed", e);
}
}
/**
* 执行推理。输入是一个二维浮点数组,代表批处理的特征。
* @param features 输入特征, shape [batch_size, num_features]
* @return 推理结果, shape [batch_size, num_outputs]
*/
public float[][] predict(float[][] features) {
// Session 是推理的入口,但它不是线程安全的。
// 不过,可以从线程安全的 SavedModelBundle 中为每个请求创建一个新的 Session。
// 更好的做法是在每个工作线程中持有一个 Session 实例。
// 这里为了简化,我们在方法内创建和关闭。
try (Session session = modelBundle.session()) {
// 1. 将 Java 原生数组转换为 TensorFlow Tensor
// TFloat32.tensorOf() 是创建Tensor的便捷方法
try (Tensor<TFloat32> inputTensor = TFloat32.tensorOf(NdArrays.ofFloats(Shape.of(features.length, features[0].length), features))) {
// 2. 运行推理
// "serving_default" 是 SavedModel 的默认签名
// "input_1" 是模型输入层的名称,这需要和模型定义者确认
// "dense_1" 是模型输出层的名称
Tensor<?> resultTensor = session.runner()
.feed("serving_default_input_1", inputTensor) // 输入节点名要与模型匹配
.fetch("StatefulPartitionedCall:0") // 输出节点名也要与模型匹配
.run()
.get(0);
// 3. 将结果 Tensor 转换回 Java 原生数组
// 注意结果张量的 shape 和数据类型
if (!(resultTensor instanceof TFloat32)) {
throw new IllegalStateException("Expected TFloat32 tensor as result, but got " + resultTensor.getClass().getName());
}
TFloat32 resultTFloat32 = (TFloat32) resultTensor;
long[] shape = resultTFloat32.shape().asArray();
if (shape.length != 2) {
throw new IllegalStateException("Expected a 2D result tensor, but shape is " + java.util.Arrays.toString(shape));
}
// 创建一个 Java 数组来接收结果
float[][] prediction = new float[(int)shape[0]][(int)shape[1]];
resultTFloat32.copyTo(prediction);
return prediction;
}
} catch (Exception e) {
log.error("Inference execution failed", e);
// 根据业务需求返回null, 空数组, 或抛出异常
return new float[0][0];
}
}
@Override
public void close() {
if (modelBundle != null) {
modelBundle.close();
log.info("TensorFlow model bundle closed.");
}
}
}
这里的坑在于:
- 线程安全:
SavedModelBundle是线程安全的,但它内部的Session不是。最佳实践是在每个处理线程中创建一个Session实例并复用它,而不是像上面示例中那样在每次predict调用时都创建。 - 节点名称:
feed()和fetch()中的节点名称(如serving_default_input_1,StatefulPartitionedCall:0)必须与SavedModel中的签名定义完全匹配。这些信息需要通过saved_model_cli工具或与模型开发者沟通来获取。 - 资源管理:
Tensor和SavedModelBundle都占用了本地内存(甚至可能是GPU显存),必须用try-with-resources或手动close()来确保资源被释放,否则会导致严重的内存泄漏。
3. 集成Pulsar消费者与异步处理
接下来,我们创建一个Pulsar消费者来接收数据,调用InferenceService,然后将结果发送出去。为了提高性能并处理背压,我们采用完全异步的模式。
import org.apache.pulsar.client.api.*;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class PulsarInferenceWorker implements Runnable, AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(PulsarInferenceWorker.class);
private final PulsarClient client;
private final Consumer<byte[]> consumer;
private final Producer<byte[]> producer;
private final InferenceService inferenceService;
// 使用专用的线程池来执行计算密集型的推理任务,避免阻塞Pulsar的IO线程
private final ExecutorService inferenceExecutor;
public PulsarInferenceWorker(String serviceUrl, String inputTopic, String outputTopic, String modelPath) throws PulsarClientException {
this.client = PulsarClient.builder().serviceUrl(serviceUrl).build();
this.consumer = client.newConsumer()
.topic(inputTopic)
.subscriptionName("keras-inference-subscription")
.subscriptionType(SubscriptionType.Shared) // 使用 Shared 模式最大化并行
.receiverQueueSize(100) // 控制内存中的消息缓冲
.subscribe();
this.producer = client.newProducer()
.topic(outputTopic)
.create();
this.inferenceService = new InferenceService(modelPath);
this.inferenceExecutor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
}
@Override
public void run() {
log.info("Pulsar inference worker started. Listening on topic: {}", consumer.getTopic());
while (true) {
try {
// 1. 异步接收消息
CompletableFuture<Message<byte[]>> future = consumer.receiveAsync();
future.thenAcceptAsync(message -> {
try {
// 2. 解析消息并执行推理
// 生产环境中,这里应该是反序列化,例如JSON或Avro
float[][] features = parseMessage(message.getData());
float[][] result = inferenceService.predict(features);
byte[] outputPayload = serializeResult(result);
// 3. 异步发送结果
producer.sendAsync(outputPayload).thenAccept(messageId -> {
log.debug("Successfully produced result for messageId: {}", message.getMessageId());
// 4. 只有当结果成功发送后,才确认原始消息
consumer.acknowledgeAsync(message);
}).exceptionally(ex -> {
log.error("Failed to produce result for messageId: {}", message.getMessageId(), ex);
// 生产失败,不确认消息,Pulsar会自动重传
consumer.negativeAcknowledge(message);
return null;
});
} catch (Exception e) {
log.error("Error processing messageId: {}", message.getMessageId(), e);
// 处理失败,不确认消息,让Pulsar在延迟后重传
consumer.negativeAcknowledge(message);
}
}, inferenceExecutor).exceptionally(ex -> {
log.error("Failed to receive message from Pulsar", ex);
// 发生严重错误,可能需要重启服务
return null;
});
// 等待异步操作完成,这里只是为了演示循环,实际应用中可以有更复杂的逻辑
future.join();
} catch (Exception e) {
log.error("Pulsar consumer loop interrupted", e);
break;
}
}
}
// 单元测试时需要mock这些方法
protected float[][] parseMessage(byte[] data) {
// 伪代码: 将字节数组解析为二维浮点数组
// 在真实项目中,这里会使用Jackson, Gson, or Avro
return new float[][]{{1.0f, 2.0f, 3.0f}};
}
protected byte[] serializeResult(float[][] result) {
// 伪代码: 将结果序列化为字节数组
return java.util.Arrays.deepToString(result).getBytes(StandardCharsets.UTF_8);
}
@Override
public void close() throws Exception {
log.info("Closing Pulsar worker...");
if (inferenceExecutor != null) {
inferenceExecutor.shutdown();
inferenceExecutor.awaitTermination(30, TimeUnit.SECONDS);
}
if (producer != null) producer.close();
if (consumer != null) consumer.close();
if (client != null) client.close();
if (inferenceService != null) inferenceService.close();
log.info("Pulsar worker closed.");
}
}
这个架构展示了事件驱动模型推理的核心流程。
sequenceDiagram
participant PulsarTopicIn as "Input Topic"
participant Worker as "PulsarInferenceWorker"
participant InferenceExecutor as "Inference Thread Pool"
participant TFService as "InferenceService"
participant PulsarTopicOut as "Output Topic"
PulsarTopicIn->>Worker: consumer.receiveAsync()
Worker->>InferenceExecutor: thenAcceptAsync(message, executor)
Note right of Worker: Pulsar I/O 线程被释放,不被阻塞
InferenceExecutor->>InferenceExecutor: parseMessage(message)
InferenceExecutor->>TFService: predict(features)
TFService-->>InferenceExecutor: return result
InferenceExecutor->>InferenceExecutor: serializeResult(result)
InferenceExecutor->>PulsarTopicOut: producer.sendAsync(payload)
PulsarTopicOut-->>InferenceExecutor: CompletableFuture
InferenceExecutor->>PulsarTopicIn: consumer.acknowledgeAsync(message)
Note right of InferenceExecutor: 确认(ack)发生在结果发送成功后
4. 使用Jib实现无守护进程的容器化
最后一步是配置Jib Maven插件。在pom.xml的<build><plugins>部分加入以下配置:
<!-- pom.xml -->
<plugin>
<groupId>com.google.cloud.tools</groupId>
<artifactId>jib-maven-plugin</artifactId>
<version>${jib-maven-plugin.version}</version>
<configuration>
<from>
<!-- 使用一个包含glibc的标准基础镜像 -->
<image>eclipse-temurin:11-jre</image>
</from>
<to>
<!-- 目标镜像仓库地址和名称 -->
<image>my-registry.example.com/ml-inference/keras-pulsar-service</image>
<tags>
<tag>${project.version}</tag>
<tag>latest</tag>
</tags>
<!-- 配置凭据,通常在CI环境中通过环境变量或配置文件提供 -->
<auth>
<username>${env.REGISTRY_USER}</username>
<password>${env.REGISTRY_PASSWORD}</password>
</auth>
</to>
<container>
<!-- 设置JVM参数,对于内存敏感的模型服务至关重要 -->
<jvmFlags>
<jvmFlag>-Xms512m</jvmFlag>
<jvmFlag>-Xmx2048m</jvmFlag>
<jvmFlag>-XX:+UseG1GC</jvmFlag>
</jvmFlags>
<mainClass>com.yourcompany.MainApplication</mainClass>
<ports>
<!-- 如果有健康检查端点,可以暴露端口 -->
<!-- <port>8080</port> -->
</ports>
<creationTime>USE_CURRENT_TIMESTAMP</creationTime>
</container>
</configuration>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>build</goal> <!-- 'build' 会构建并推送到远程仓库 -->
</goals>
</execution>
</executions>
</plugin>
现在,在CI流水线中,我们不再需要docker login和docker build。一个简单的命令就完成了所有工作:
mvn compile package -DREGISTRY_USER=... -DREGISTRY_PASSWORD=...
jib-maven-plugin会在package阶段自动执行。它会:
- 拉取基础镜像
eclipse-temurin:11-jre。 - 将项目依赖(JARs)作为一个层。
- 将项目资源(包括
src/main/resources下的模型文件)作为另一个层。 - 将编译后的
*.class文件作为最顶层。 - 将这些层组合成一个符合OCI标准的镜像,并直接推送到
my-registry.example.com。
当我们只修改了PulsarInferenceWorker.java中的业务逻辑时,只有包含.class文件的那个几KB大小的层会被重新构建和推送,CI/CD的效率得到了质的提升。
遗留问题与未来迭代路径
这个方案解决了我们最初的痛点,但它并非终点。在生产环境中,我们很快会遇到新的挑战。
当前架构最大的局限在于模型更新。每次模型迭代都需要重新构建和部署整个服务。一个更优雅的方案是实现模型的热加载。这可以设计成一个独立的控制流:运维或MLOps平台向一个专用的Pulsar控制Topic发送一条消息,其中包含新模型的路径(例如S3 URI)。服务实例消费到这条消息后,会安全地下载新模型,并在原子操作中替换掉内存中的旧模型实例。这需要仔细处理并发访问和资源释放,确保在切换过程中服务不中断。
其次,当前的实现是纯CPU推理。对于需要更高吞吐量的场景,支持GPU是必须的。这需要将Jib的基础镜像换成带有CUDA库的NVIDIA官方镜像(如nvidia/cuda:11.4.2-base-ubuntu20.04),并在Kubernetes部署文件中通过nodeSelector和resources.limits来请求GPU资源。这为基础设施和部署配置带来了额外的复杂度。
最后,为了极致的性能,可以引入动态批处理(Dynamic Batching)。服务不再是来一条消息处理一条,而是在一个极短的时间窗口内(如10毫秒)聚合多条消息,将它们的特征数据合并成一个更大的Tensor批次,一次性送入模型进行推理。这能显著提高硬件利用率,尤其是GPU。然而,它以牺牲少量延迟为代价,并且需要更复杂的逻辑来处理批次内的部分失败和结果分发。