基于Tornado与ScyllaDB构建支持动态请求批处理的机器学习推理网关


一个看似标准的任务摆在了面前:将一个已经训练好的 Scikit-learn 模型部署为在线服务。模型本身不复杂,一个用于实时风险评估的梯度提升树,输入几十个特征,输出一个风险评分。最初使用 Flask 简单包装一下,app.run() 之后,单次请求的响应时间在 15ms 左右,看起来还不错。但压力测试的结果却是一场灾难。并发量一旦超过 50,响应延迟就急剧攀升到数百毫秒,CPU 占用率轻易触顶,整个服务吞吐量完全不达标。

问题根源很清晰:Python 的 GIL 和 Scikit-learn 模型 predict 方法的同步阻塞特性。在高并发场景下,每一个进来的请求都会在 WSGI 的工作线程里等待模型计算完成,而模型计算是 CPU 密集型的。这导致大量线程上下文切换和 GIL 争抢,I/O 和 CPU 的效率都极其低下。我们需要一个完全不同的架构。

初步构想是切换到异步 I/O 模型,这能有效解决 I/O 等待问题。Tornado 是一个成熟的选择。但仅仅切换到 Tornado 并不能解决核心矛盾:Tornado 的事件循环是单线程的,任何同步阻塞的操作,尤其是我们这个 CPU 密集的 predict 调用,都会冻结整个事件循环,导致所有其他请求被挂起。这比多线程阻塞模型更糟糕。

真正的解决方案必须同时解决 I/O 并发和 CPU 计算效率。由此,一个核心架构思路浮出水面:**动态请求批处理 (Dynamic Request Batching)**。我们不再让每个请求都独立触发一次模型计算,而是在服务内部设置一个短暂的缓冲窗口。在这个窗口期(比如 5-10 毫秒)内到达的所有请求,将被合并成一个批次,然后一次性送入模型进行批量预测。这样做的好处是显而易见的:

  1. 摊销开销:批量调用 predict 的开销远小于多次独立调用。
  2. 提升并行度:可以更好地利用现代 CPU 的 SIMD 指令。
  3. 解耦 I/O 与计算:Tornado 的主事件循环只负责接收请求、放入批处理队列,然后异步等待结果,完全不被计算阻塞。实际的计算任务则被抛到一个独立的线程池中执行。

基于此,技术选型也变得明确:

  • Web框架: Tornado。利用其原生的异步能力和对 asyncio 的完美支持,作为请求的接入和调度层。
  • 特征存储: ScyllaDB。推理前需要从数据库中实时拉取用户特征。这个环节的延迟必须极低。ScyllaDB 作为 Cassandra 的 C++ 高性能重写,提供了 P99 在个位数毫秒的读延迟,并且其分片架构天然支持高并发查询,非常适合这个场景。
  • 模型: Scikit-learn。这是既定条件,我们需要在架构上弥补它的部署短板。
  • 网关: 生产环境中,服务前置一个 API Gateway (如 Kong, APISIX) 来处理认证、限流、日志等通用逻辑,本文则聚焦于推理服务本身的设计与实现。

整个系统的请求生命周期将如下所示:

sequenceDiagram
    participant Client
    participant API Gateway
    participant Tornado Service
    participant RequestBatcher
    participant FeatureStore (ScyllaDB)
    participant ThreadPool (Scikit-learn)

    Client->>+API Gateway: POST /predict (data_1)
    API Gateway->>+Tornado Service: Forward Request
    Tornado Service->>+RequestBatcher: Add request_1 (with future_1) to queue
    Client->>+API Gateway: POST /predict (data_2)
    API Gateway->>+Tornado Service: Forward Request
    Tornado Service->>+RequestBatcher: Add request_2 (with future_2) to queue
    
    Note over RequestBatcher: Batch window (e.g., 10ms) timer starts or max batch size reached.
    
    RequestBatcher->>RequestBatcher: Form batch [req_1, req_2]
    RequestBatcher->>+FeatureStore (ScyllaDB): ASYNC query features for batch
    FeatureStore (ScyllaDB)-->>-RequestBatcher: Return features for [req_1, req_2]
    
    RequestBatcher->>+ThreadPool (Scikit-learn): Offload batch predict
    ThreadPool (Scikit-learn)-->>-RequestBatcher: Return predictions [pred_1, pred_2]
    
    RequestBatcher->>Tornado Service: Resolve future_1 with pred_1
    RequestBatcher->>Tornado Service: Resolve future_2 with pred_2
    
    Tornado Service-->>-API Gateway: Respond for request_1
    API Gateway-->>-Client: Return prediction_1
    Tornado Service-->>-API Gateway: Respond for request_2
    API Gateway-->>-Client: Return prediction_2

第一步: 环境与数据模型准备

我们首先需要一个 ScyllaDB 实例,并定义好特征表的结构。在真实项目中,特征可能非常复杂,这里我们简化为一个用户画像表。

ScyllaDB 表结构 (CQL):

CREATE KEYSPACE IF NOT EXISTS model_features WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };

USE model_features;

CREATE TABLE IF NOT EXISTS user_profiles (
    user_id text PRIMARY KEY,
    feature_1 float,
    feature_2 float,
    feature_3 float,
    last_updated timestamp
);

-- 插入一些样本数据
INSERT INTO user_profiles (user_id, feature_1, feature_2, feature_3, last_updated) VALUES ('user-001', 0.5, 1.2, -0.8, toTimestamp(now()));
INSERT INTO user_profiles (user_id, feature_1, feature_2, feature_3, last_updated) VALUES ('user-002', -1.1, 0.3, 2.5, toTimestamp(now()));

同时,我们需要一个预先训练并保存的 Scikit-learn 模型。这里我们用一个简单的 LogisticRegression 模型作为示例,并用 pickle 保存。

模型训练与保存脚本 (train_model.py):

import numpy as np
from sklearn.linear_model import LogisticRegression
import pickle

# 模拟训练数据
X_train = np.random.rand(100, 3) * 10 - 5
y_train = (X_train[:, 0] + X_train[:, 1] * 0.5 - X_train[:, 2] > 0).astype(int)

# 训练一个简单的模型
model = LogisticRegression()
model.fit(X_train, y_train)

# 保存模型
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

print("Model trained and saved to model.pkl")

执行 python train_model.py 会生成 model.pkl 文件。

第二步: 核心组件实现

项目结构如下:

.
├── config.py           # 配置文件
├── feature_store.py    # ScyllaDB 数据访问层
├── model_loader.py     # 模型加载器
├── request_batcher.py  # 核心:动态请求批处理模块
├── server.py           # Tornado 应用主文件
├── model.pkl           # 训练好的模型
└── train_model.py      # 模型训练脚本 (仅用于准备)

config.py - 集中管理配置
将所有可调参数放在这里,便于维护。

# config.py
import logging

# Server Configuration
SERVER_PORT = 8888
SERVER_WORKERS = 4 # 生产环境建议与 CPU 核心数一致

# ScyllaDB / Cassandra Configuration
SCYLLA_HOSTS = ['127.0.0.1']
SCYLLA_PORT = 9042
SCYLLA_KEYSPACE = 'model_features'

# Model Configuration
MODEL_PATH = 'model.pkl'

# Batcher Configuration
# 最大批处理大小
MAX_BATCH_SIZE = 64
# 批处理窗口超时时间 (毫秒),这是延迟和吞吐量之间的关键权衡点
BATCH_TIMEOUT_MS = 10 

# Thread Pool for CPU-bound tasks
# 线程池大小,通常设置为 CPU 核心数
WORKER_THREADS = 4

# Logging Configuration
LOGGING_LEVEL = logging.INFO

feature_store.py - 异步特征获取
这一层必须是完全异步的。cassandra-driver 提供了基于 asyncio 的异步执行接口。

# feature_store.py
import asyncio
from cassandra.cluster import Cluster
from cassandra.query import SimpleStatement
from typing import List, Dict, Any

import config

class FeatureStore:
    def __init__(self):
        self.cluster = Cluster(config.SCYLLA_HOSTS, port=config.SCYLLA_PORT)
        self.session = None

    async def connect(self):
        """异步连接到数据库"""
        loop = asyncio.get_running_loop()
        # session.connect 是阻塞的, 在 executor 中运行
        self.session = await loop.run_in_executor(None, self.cluster.connect, config.SCYLLA_KEYSPACE)
        print("Successfully connected to ScyllaDB.")
        # 准备查询语句可以提升性能
        self.prepared_statement = self.session.prepare("SELECT feature_1, feature_2, feature_3 FROM user_profiles WHERE user_id = ?")

    async def disconnect(self):
        """关闭连接"""
        if self.cluster:
            await asyncio.get_running_loop().run_in_executor(None, self.cluster.shutdown)
            print("Disconnected from ScyllaDB.")

    async def get_features_batch(self, user_ids: List[str]) -> Dict[str, Any]:
        """
        异步批量获取特征。这是性能关键路径。
        使用 asyncio.gather 并发执行所有查询。
        """
        if not self.session:
            raise ConnectionError("Not connected to the database.")

        async def fetch_one(user_id):
            # 使用 driver 的 execute_async 方法
            future = self.session.execute_async(self.prepared_statement, (user_id,))
            try:
                rows = await asyncio.wrap_future(future)
                row = rows.one()
                if row:
                    return user_id, [row.feature_1, row.feature_2, row.feature_3]
                else:
                    # 在真实项目中,需要定义好特征缺失的处理策略
                    # 这里我们用默认值填充
                    return user_id, [0.0, 0.0, 0.0]
            except Exception as e:
                # 必须处理单个查询失败的情况,避免整个批次失败
                print(f"Error fetching features for {user_id}: {e}")
                return user_id, [0.0, 0.0, 0.0]

        tasks = [fetch_one(user_id) for user_id in user_ids]
        results = await asyncio.gather(*tasks)
        
        return dict(results)

这里的坑在于,即使是 execute_async,其返回的 ResponseFuture 也不是原生的 asyncio.Future。需要使用 asyncio.wrap_future 进行转换,才能在 async/await 语法中使用。批量获取时使用 asyncio.gather 是最大化 I/O 并发的关键。

model_loader.py - 模型加载
一个简单的封装,确保模型在服务启动时加载一次,而不是在每次请求时加载。

# model_loader.py
import pickle
from sklearn.base import BaseEstimator
import config

class ModelLoader:
    def __init__(self, model_path: str):
        self._model_path = model_path
        self._model = None

    def load(self):
        """加载模型到内存"""
        try:
            with open(self._model_path, 'rb') as f:
                self._model = pickle.load(f)
            print(f"Model loaded successfully from {self._model_path}")
        except FileNotFoundError:
            raise RuntimeError(f"Model file not found at {self._model_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}")

    @property
    def model(self) -> BaseEstimator:
        if self._model is None:
            raise RuntimeError("Model is not loaded. Call load() first.")
        return self._model

# 创建全局单例
model_loader = ModelLoader(config.MODEL_PATH)

request_batcher.py - 系统的“心脏”
这是整个架构最核心的部分。它负责收集请求、触发批处理、分发结果。

# request_batcher.py
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any
import numpy as np

import config
from feature_store import FeatureStore
from model_loader import ModelLoader

class PredictionRequest:
    """封装单个预测请求及其用于接收结果的 Future"""
    def __init__(self, user_id: str, request_data: Dict):
        self.user_id = user_id
        self.request_data = request_data # 保留原始请求数据,可能包含其他信息
        self.future = asyncio.Future()

class RequestBatcher:
    def __init__(self, model_loader: ModelLoader, feature_store: FeatureStore):
        self._model_loader = model_loader
        self._feature_store = feature_store
        self._queue = asyncio.Queue()
        self._executor = ThreadPoolExecutor(max_workers=config.WORKER_THREADS)
        self._background_task = None

    async def start(self):
        """启动后台批处理循环任务"""
        self._background_task = asyncio.create_task(self._batching_loop())
        print("Request batcher started.")

    async def stop(self):
        """停止后台任务"""
        if self._background_task:
            self._background_task.cancel()
            try:
                await self._background_task
            except asyncio.CancelledError:
                pass
        self._executor.shutdown(wait=True)
        print("Request batcher stopped.")

    async def submit_request(self, request: PredictionRequest):
        """外部调用方(Tornado Handler)提交请求"""
        await self._queue.put(request)
        return await request.future

    async def _batching_loop(self):
        """
        核心循环:
        1. 持续从队列中拉取请求,形成批次。
        2. 批次形成的条件是:达到最大数量 或 超过等待超时。
        3. 形成批次后,异步处理它。
        """
        while True:
            batch: List[PredictionRequest] = []
            try:
                # 等待第一个请求,设置超时
                timeout = config.BATCH_TIMEOUT_MS / 1000.0
                first_request = await asyncio.wait_for(self._queue.get(), timeout)
                batch.append(first_request)
                
                # 第一个请求到达后,快速拉取队列中已有的其他请求,直到达到批次上限
                start_time = time.monotonic()
                while len(batch) < config.MAX_BATCH_SIZE:
                    # 计算剩余的等待时间
                    remaining_time = timeout - (time.monotonic() - start_time)
                    if remaining_time <= 0:
                        break
                    try:
                        # 使用 non-blocking get_nowait 或短暂的超时
                        req = self._queue.get_nowait()
                        batch.append(req)
                    except asyncio.QueueEmpty:
                        break # 队列空了,立即处理当前批次
                
                if batch:
                    # 创建一个新任务来处理批次,不阻塞主循环
                    asyncio.create_task(self._process_batch(batch))

            except asyncio.TimeoutError:
                # 超时没有收到任何请求,继续循环
                continue
            except asyncio.CancelledError:
                print("Batching loop cancelled.")
                break
            except Exception as e:
                print(f"Error in batching loop: {e}")
                await asyncio.sleep(1) # 发生错误时稍作等待,防止CPU空转

    async def _process_batch(self, batch: List[PredictionRequest]):
        """处理一个完整的批次"""
        start_time = time.time()
        user_ids = [req.user_id for req in batch]
        
        try:
            # 1. 异步批量获取特征
            features_dict = await self._feature_store.get_features_batch(user_ids)
            
            # 按照批次原始顺序准备模型输入
            model_input = np.array([features_dict[uid] for uid in user_ids])
            
            # 2. 将 CPU 密集的模型预测任务扔到线程池
            loop = asyncio.get_running_loop()
            predictions = await loop.run_in_executor(
                self._executor, self._model_loader.model.predict_proba, model_input
            )

            # 3. 将结果分发回每个请求的 Future
            for i, req in enumerate(batch):
                # predict_proba 返回每个类别的概率,我们取类别1的概率
                result = {"user_id": req.user_id, "score": float(predictions[i][1])}
                req.future.set_result(result)

        except Exception as e:
            # 如果批处理过程中任何环节出错,需要通知所有等待的请求
            print(f"Error processing batch: {e}")
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(e)
        
        latency = (time.time() - start_time) * 1000
        print(f"Processed batch of {len(batch)} requests in {latency:.2f} ms.")

这里的实现细节是魔鬼。_batching_loop 的逻辑至关重要,它决定了批次的动态形成。我们不能简单地循环 _queue.get(),否则会一直等待直到队列满。正确的做法是:在收到第一个请求后,启动一个非常短的计时器,在这个计时器内尽可能多地从队列中拉取请求。这保证了低流量时请求不会被过度延迟,高流量时又能形成大批次。

第三步: 组装 Tornado 服务器

现在,我们将所有组件整合到 server.py 中。

# server.py
import asyncio
import json
from tornado.web import Application, RequestHandler, HTTPError
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer

import config
from feature_store import FeatureStore
from model_loader import ModelLoader
from request_batcher import RequestBatcher, PredictionRequest

class PredictionHandler(RequestHandler):
    
    def initialize(self, batcher: RequestBatcher):
        # 通过 initialize 注入依赖
        self.batcher = batcher

    async def post(self):
        """处理预测请求的端点"""
        try:
            body = json.loads(self.request.body)
            user_id = body.get('user_id')
            if not user_id:
                raise HTTPError(400, reason="user_id is required.")

            # 创建请求对象
            pred_request = PredictionRequest(user_id=user_id, request_data=body)
            
            # 提交到批处理器并异步等待结果
            # 这里是 Tornado 事件循环与批处理系统的交汇点
            # Handler 会在此处 'await',将控制权交还给事件循环,直到 future 完成
            result = await self.batcher.submit_request(pred_request)

            self.set_header("Content-Type", "application/json")
            self.write(json.dumps(result))

        except json.JSONDecodeError:
            raise HTTPError(400, reason="Invalid JSON format.")
        except Exception as e:
            # 捕获批处理过程中可能发生的异常
            raise HTTPError(500, reason=f"Internal server error: {str(e)}")


async def main():
    # 1. 加载模型
    model_loader.load()

    # 2. 初始化并连接特征存储
    feature_store = FeatureStore()
    await feature_store.connect()

    # 3. 初始化并启动请求批处理器
    batcher = RequestBatcher(model_loader, feature_store)
    await batcher.start()

    # 4. 创建 Tornado 应用
    app = Application([
        (r"/predict", PredictionHandler, dict(batcher=batcher)),
    ])
    
    server = HTTPServer(app)
    server.bind(config.SERVER_PORT)
    # 在生产环境中,num_processes 应大于1,通常设为 CPU 核心数
    # Tornado 会自动处理多进程间的负载均衡
    server.start(config.SERVER_WORKERS)
    
    print(f"Server listening on http://localhost:{config.SERVER_PORT}")
    
    # 保持主进程运行并处理关闭信号
    shutdown_event = asyncio.Event()
    await shutdown_event.wait()


if __name__ == "__main__":
    # 运行主程序
    try:
        IOLoop.current().run_sync(main)
    except KeyboardInterrupt:
        print("Server shutting down...")
        # 优雅停机的逻辑应该在这里实现,例如调用 batcher.stop() 和 feature_store.disconnect()
        # tornado 6+ 会自动处理

Tornado RequestHandler 的实现非常简洁。它的全部工作就是解析请求,创建一个 PredictionRequest 对象,然后 await batcher.submit_request 的结果。所有的复杂性都被封装在了 RequestBatcher 内部。

局限性与未来迭代路径

这套架构解决了 Scikit-learn 模型在原生 Python 环境下高并发部署的核心痛点,但它并非银弹。

首先,BATCH_TIMEOUT_MS 是一个需要精细调优的参数。设置太长会增加单个请求的延迟,太短则批处理效果不佳。在真实场景中,这个值甚至可以是动态的,根据当前系统的负载自适应调整。

其次,我们仍然受限于 Python GIL 和 ThreadPoolExecutor。虽然 run_in_executor 避免了阻塞事件循环,但它本质上是在多线程中执行计算。对于计算密集型任务,多线程的性能提升有限。若要进一步压榨性能,可以考虑将服务进程化,利用 tornado.process.fork_processes 或 Gunicorn 等工具启动多个独立的 Python 进程,每个进程拥有自己的 Tornado 实例和批处理器,从而真正利用多核 CPU。

再者,Scikit-learn 模型本身并非为高性能推理设计。如果延迟要求达到极致(例如亚毫秒级),正确的路径是将模型转换为更高效的推理格式,如 ONNX,并使用专门的 C++/Rust 推理引擎(如 ONNX Runtime, Triton Inference Server)来执行。我们这套架构可以作为这些专用推理服务器的前置代理,负责业务逻辑、特征拉取和请求批处理,然后通过 RPC 调用后端推理引擎。

最后,一个生产级的系统还需要完善的可观测性。需要在 RequestBatcher 中加入详细的监控指标,例如:输入 QPS、批次大小分布、批处理延迟、队列深度、线程池利用率等,通过 Prometheus 暴露出来,以便进行容量规划和故障排查。


  目录