Model Serving Architecture: From Batch to Real-Time Inference at Scale
Training a great model is only half the challenge. Serving that model to users reliably, with low latency and at scale, is where most ML projects struggle. This guide covers the architectural patterns for both batch and real-time model serving.
Choosing Your Serving Pattern
| Pattern | Latency | Throughput | Use Case |
|---------|---------|------------|----------|
| Batch | Minutes-hours | Very high | Reports, recommendations, scoring |
| Real-time REST | 50-200ms | Medium | Web/mobile APIs |
| Real-time gRPC | 10-50ms | High | Internal services |
| Streaming | Sub-second | Continuous | Fraud detection, real-time pricing |
| Edge | < 10ms | Per-device | IoT, mobile on-device |
Pattern 1: Real-Time Serving with FastAPI
For teams starting out or serving simpler models, a FastAPI-based serving layer provides maximum flexibility:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import mlflow.pyfunc
import pandas as pd
import time
import logging
from contextlib import asynccontextmanager
from prometheus_client import Histogram, Counter, generate_latest
# Metrics
PREDICTION_LATENCY = Histogram("prediction_latency_seconds", "Time to generate prediction", ["model_name"])
PREDICTION_COUNT = Counter("prediction_total", "Total predictions", ["model_name", "status"])
# Model cache
models: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load models on startup."""
models["churn"] = mlflow.pyfunc.load_model("models:/churn-predictor/Production")
models["pricing"] = mlflow.pyfunc.load_model("models:/dynamic-pricing/Production")
logging.info(f"Loaded {len(models)} models")
yield
models.clear()
app = FastAPI(title="ML Serving API", lifespan=lifespan)
class PredictionRequest(BaseModel):
features: dict[str, float] = Field(..., description="Feature name-value pairs")
model_name: str = Field(default="churn", description="Model to use for prediction")
class PredictionResponse(BaseModel):
prediction: float
probability: float | None = None
model_version: str
latency_ms: float
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
if request.model_name not in models:
raise HTTPException(404, f"Model '{request.model_name}' not found")
model = models[request.model_name]
start = time.perf_counter()
    try:
        df = pd.DataFrame([request.features])
result = model.predict(df)
latency_ms = (time.perf_counter() - start) * 1000
PREDICTION_LATENCY.labels(model_name=request.model_name).observe(latency_ms / 1000)
PREDICTION_COUNT.labels(model_name=request.model_name, status="success").inc()
return PredictionResponse(
prediction=float(result[0]),
            model_version=str(getattr(model.metadata, "run_id", None) or "unknown"),
latency_ms=round(latency_ms, 2),
)
except Exception as e:
PREDICTION_COUNT.labels(model_name=request.model_name, status="error").inc()
raise HTTPException(500, f"Prediction failed: {str(e)}")
@app.get("/health")
async def health():
return {"status": "healthy", "models_loaded": list(models.keys())}
@app.get("/metrics")
async def metrics():
from starlette.responses import Response
    return Response(content=generate_latest(), media_type="text/plain")
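With the service running, a client calls the endpoint over plain HTTP. A minimal sketch of a request (the host, port, and feature names below are placeholders for illustration):
import httpx
# Hypothetical host and feature names; substitute the features your model expects
response = httpx.post(
    "http://localhost:8000/predict",
    json={
        "model_name": "churn",
        "features": {"tenure_months": 14.0, "monthly_spend": 59.9, "support_tickets": 2.0},
    },
    timeout=5.0,
)
response.raise_for_status()
print(response.json())  # {"prediction": ..., "model_version": ..., "latency_ms": ...}
Pattern 2: KServe on Kubernetes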
For production-grade serving with auto-scaling, canary deployments, and GPU support:
# kserve-inference-service.yaml
apiVersion: serving.kserve.io/v1beta1
kind: InferenceService
metadata:
name: churn-predictor
namespace: ml-serving
annotations:
serving.kserve.io/autoscalerClass: hpa
spec:
predictor:
minReplicas: 2
maxReplicas: 10
scaleTarget: 70 # Scale up at 70% CPU
scaleMetric: cpu
model:
modelFormat:
name: mlflow
storageUri: "s3://ml-models/churn-predictor/v2.3"
resources:
requests:
cpu: "1"
memory: "2Gi"
limits:
cpu: "2"
memory: "4Gi"
# Canary traffic split
canaryTrafficPercent: 10
transformer:
containers:
- name: feature-transformer
image: registry.company.com/ml/feature-transformer:v1.2
resources:
requests:
cpu: "500m"
memory: "512Mi"KServe with Custom Predictor
from kserve import Model, ModelServer
import joblib
import numpy as np
class ChurnPredictor(Model):
def __init__(self, name: str):
super().__init__(name)
self.model = None
def load(self):
self.model = joblib.load("/mnt/models/model.joblib")
self.ready = True
def predict(self, payload: dict, headers: dict = None) -> dict:
instances = payload.get("instances", [])
features = np.array(instances)
predictions = self.model.predict(features)
probabilities = self.model.predict_proba(features)
return {
"predictions": predictions.tolist(),
"probabilities": probabilities.tolist(),
}
if __name__ == "__main__":
model = ChurnPredictor("churn-predictor")
model.load()
    ModelServer().start([model])
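Once deployed, the predictor speaks KServe's V1 inference protocol. A sketch of a client call, assuming a placeholder hostname (in practice, use the URL KServe publishes on the InferenceService status):
import httpx
# Placeholder URL; KServe reports the real one on the InferenceService resource
url = "http://churn-predictor.ml-serving.example.com/v1/models/churn-predictor:predict"
payload = {"instances": [[14.0, 59.9, 2.0, 1.0]]}  # one row of numeric features
response = httpx.post(url, json=payload, timeout=5.0)
print(response.json())  # {"predictions": [...], "probabilities": [...]}
Pattern 3: Batch Inference Pipeline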
For use cases where real-time isn't required (recommendations, daily scoring):
from datetime import datetime
import pandas as pd
import mlflow.pyfunc
class BatchInferencePipeline:
"""Score millions of records efficiently in batch."""
def __init__(self, model_name: str, model_stage: str = "Production"):
self.model = mlflow.pyfunc.load_model(f"models:/{model_name}/{model_stage}")
self.model_name = model_name
def score_batch(self, data: pd.DataFrame, batch_size: int = 10000) -> pd.DataFrame:
"""Score data in chunks to manage memory."""
results = []
total_rows = len(data)
for i in range(0, total_rows, batch_size):
chunk = data.iloc[i:i + batch_size]
predictions = self.model.predict(chunk)
chunk_result = chunk.copy()
chunk_result["prediction"] = predictions
chunk_result["scored_at"] = datetime.utcnow().isoformat()
chunk_result["model_name"] = self.model_name
results.append(chunk_result)
progress = min(i + batch_size, total_rows) / total_rows * 100
print(f"Progress: {progress:.1f}% ({min(i + batch_size, total_rows)}/{total_rows})")
return pd.concat(results, ignore_index=True)
# Usage
pipeline = BatchInferencePipeline("churn-predictor")
customers = pd.read_parquet("s3://data/customers/latest.parquet")
scored = pipeline.score_batch(customers)
scored.to_parquet(f"s3://predictions/churn/{datetime.utcnow().strftime('%Y-%m-%d')}.parquet")
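If a single process is too slow for very large tables, chunks can also be fanned out across worker processes. A rough sketch (the score_parallel helper is hypothetical and assumes the registered model is cheap enough to reload in each worker):
import math
from concurrent.futures import ProcessPoolExecutor
import mlflow.pyfunc
import pandas as pd
def _score_chunk(chunk: pd.DataFrame, model_name: str) -> pd.DataFrame:
    # Each worker loads its own copy of the model; loaded pyfunc objects
    # are not reliably picklable across processes
    model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
    out = chunk.copy()
    out["prediction"] = model.predict(chunk)
    return out
def score_parallel(data: pd.DataFrame, model_name: str, n_workers: int = 4) -> pd.DataFrame:
    chunk_size = math.ceil(len(data) / n_workers)
    chunks = [data.iloc[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        results = pool.map(_score_chunk, chunks, [model_name] * len(chunks))
    return pd.concat(list(results), ignore_index=True)
A/B Testing Model Versions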
import hashlib
class ABRouter:
"""Deterministic A/B routing for model experiments."""
def __init__(self, experiments: dict[str, dict]):
"""
experiments = {
"control": {"model": model_a, "traffic": 0.80},
"treatment": {"model": model_b, "traffic": 0.20},
}
"""
self.experiments = experiments
self._validate_traffic()
def _validate_traffic(self):
total = sum(e["traffic"] for e in self.experiments.values())
assert abs(total - 1.0) < 0.01, f"Traffic must sum to 1.0, got {total}"
def route(self, entity_id: str) -> tuple[str, object]:
"""Deterministically assign entity to experiment variant."""
# Hash entity ID for consistent assignment
hash_val = int(hashlib.sha256(entity_id.encode()).hexdigest(), 16) % 10000
normalized = hash_val / 10000
cumulative = 0.0
for variant_name, config in self.experiments.items():
cumulative += config["traffic"]
if normalized < cumulative:
return variant_name, config["model"]
# Fallback to last variant
last_key = list(self.experiments.keys())[-1]
        return last_key, self.experiments[last_key]["model"]
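Routing happens per request, keyed on a stable entity ID so the same user always lands in the same variant. A usage sketch with stand-in objects (model_a, model_b, and features_df are placeholders for already-loaded models and a prepared feature frame):
import logging
# model_a / model_b are placeholders for models loaded elsewhere (e.g. via mlflow.pyfunc)
router = ABRouter({
    "control": {"model": model_a, "traffic": 0.80},
    "treatment": {"model": model_b, "traffic": 0.20},
})
variant, model = router.route("customer-12345")
prediction = model.predict(features_df)
# Log the assignment so offline analysis can attribute outcomes to variants
logging.info("ab_assignment entity=%s variant=%s", "customer-12345", variant)
Model Optimization for Serving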
Quantization and ONNX Export
import torch
import onnx
import onnxruntime as ort
import numpy as np
def optimize_pytorch_model(model, sample_input, output_path: str):
"""Export PyTorch model to ONNX for faster inference."""
model.eval()
# Export to ONNX
torch.onnx.export(
model,
sample_input,
output_path,
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
# Validate
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
return output_path
def serve_onnx(model_path: str, features: np.ndarray) -> np.ndarray:
"""Run inference with ONNX Runtime, typically 2-5x faster than PyTorch."""
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
result = session.run(None, {input_name: features.astype(np.float32)})
    return result[0]
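The exported graph can also be dynamically quantized with ONNX Runtime, converting weights to int8, which usually shrinks the file and speeds up CPU inference; validate accuracy on held-out data before rollout. A minimal sketch (file paths are placeholders):
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
    model_input="model.onnx",        # placeholder: the exported ONNX file
    model_output="model.int8.onnx",  # placeholder: where to write the quantized model
    weight_type=QuantType.QInt8,
)
Serving Infrastructure Comparison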
| Framework | Protocol | Auto-scaling | GPU | Multi-model | Best For |
|-----------|----------|--------------|-----|-------------|----------|
| FastAPI | REST | Manual/K8s HPA | Manual | Manual | Simple, flexible |
| KServe | REST/gRPC | Built-in | Native | Built-in | K8s production |
| Seldon Core | REST/gRPC | Built-in | Native | Built-in | Complex graphs |
| BentoML | REST/gRPC | BentoCloud | Supported | Built-in | Easy packaging |
| TF Serving | REST/gRPC | Manual | Native | Built-in | TensorFlow only |
| Triton | REST/gRPC | Built-in | Optimized | Built-in | GPU inference |
Monitoring Serving Infrastructure
Key metrics to track in production:
# Prometheus alerts for model serving
groups:
- name: ml-serving
rules:
- alert: HighPredictionLatency
        expr: histogram_quantile(0.99, sum by (le, model_name) (rate(prediction_latency_seconds_bucket[5m]))) > 0.2
for: 5m
labels:
severity: warning
annotations:
summary: "P99 prediction latency exceeds 200ms"
- alert: HighErrorRate
expr: rate(prediction_total{status="error"}[5m]) / rate(prediction_total[5m]) > 0.01
for: 2m
labels:
severity: critical
annotations:
summary: "Model error rate exceeds 1%"
- alert: LowThroughput
        expr: sum(rate(prediction_total[5m])) < 1
for: 10m
labels:
severity: warning
annotations:
summary: "Prediction throughput dropped below 1 req/s"Related Resources
- MLOps overview: Where model serving fits in the MLOps lifecycle
- ML CI/CD: Automate model deployment
- Model monitoring: Monitor served models for degradation
- Kubeflow Pipelines: Orchestrate training that feeds your serving layer
Need production model serving? DeviDevs builds scalable ML serving infrastructure with auto-scaling, A/B testing, and monitoring. Get a free assessment →