Model Serving Architecture: From Batch to Real-Time Inference at Scale
Training a great model is only half the challenge. Serving that model to users reliably, with low latency, at scale — that's where most ML projects struggle. This guide covers the architectural patterns for both batch and real-time model serving.
Choosing Your Serving Pattern
| Pattern | Latency | Throughput | Use Case |
|---------|---------|------------|----------|
| Batch | Minutes-hours | Very high | Reports, recommendations, scoring |
| Real-time REST | 50-200ms | Medium | Web/mobile APIs |
| Real-time gRPC | 10-50ms | High | Internal services |
| Streaming | Sub-second | Continuous | Fraud detection, real-time pricing |
| Edge | < 10ms | Per-device | IoT, mobile on-device |
Pattern 1: Real-Time Serving with FastAPI
For teams starting out or serving simpler models, a FastAPI-based serving layer provides maximum flexibility:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import mlflow.pyfunc
import numpy as np
import time
import logging
from contextlib import asynccontextmanager
from prometheus_client import Histogram, Counter, generate_latest
# Metrics
PREDICTION_LATENCY = Histogram("prediction_latency_seconds", "Time to generate prediction", ["model_name"])
PREDICTION_COUNT = Counter("prediction_total", "Total predictions", ["model_name", "status"])
# Model cache
models: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load models on startup."""
models["churn"] = mlflow.pyfunc.load_model("models:/churn-predictor/Production")
models["pricing"] = mlflow.pyfunc.load_model("models:/dynamic-pricing/Production")
logging.info(f"Loaded {len(models)} models")
yield
models.clear()
app = FastAPI(title="ML Serving API", lifespan=lifespan)
class PredictionRequest(BaseModel):
features: dict[str, float] = Field(..., description="Feature name-value pairs")
model_name: str = Field(default="churn", description="Model to use for prediction")
class PredictionResponse(BaseModel):
prediction: float
probability: float | None = None
model_version: str
latency_ms: float
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
if request.model_name not in models:
raise HTTPException(404, f"Model '{request.model_name}' not found")
model = models[request.model_name]
start = time.perf_counter()
try:
import pandas as pd
df = pd.DataFrame([request.features])
result = model.predict(df)
latency_ms = (time.perf_counter() - start) * 1000
PREDICTION_LATENCY.labels(model_name=request.model_name).observe(latency_ms / 1000)
PREDICTION_COUNT.labels(model_name=request.model_name, status="success").inc()
return PredictionResponse(
prediction=float(result[0]),
            model_version=model.metadata.run_id or "unknown",
latency_ms=round(latency_ms, 2),
)
except Exception as e:
PREDICTION_COUNT.labels(model_name=request.model_name, status="error").inc()
raise HTTPException(500, f"Prediction failed: {str(e)}")
@app.get("/health")
async def health():
return {"status": "healthy", "models_loaded": list(models.keys())}
@app.get("/metrics")
async def metrics():
from starlette.responses import Response
    return Response(content=generate_latest(), media_type="text/plain")
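To sanity-check the service (run it with uvicorn, e.g. `uvicorn app:app --port 8000`), a small client script is enough. A minimal sketch using the requests library; the port and feature names below are illustrative, not part of any real model schema:
# Hypothetical smoke test for the /predict endpoint defined above.
# Assumes the app is running locally on port 8000; feature names are placeholders.
import requests
payload = {
    "model_name": "churn",
    "features": {"tenure_months": 14.0, "monthly_charges": 72.5, "support_tickets": 3.0},
}
resp = requests.post("http://localhost:8000/predict", json=payload, timeout=5)
resp.raise_for_status()
body = resp.json()
print(f"prediction={body['prediction']}, served in {body['latency_ms']}ms")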
Pattern 2: KServe on Kubernetes
For production-grade serving with auto-scaling, canary deployments, and GPU support:
# kserve-inference-service.yaml
apiVersion: serving.kserve.io/v1beta1
kind: InferenceService
metadata:
name: churn-predictor
namespace: ml-serving
annotations:
serving.kserve.io/autoscalerClass: hpa
spec:
predictor:
minReplicas: 2
maxReplicas: 10
scaleTarget: 70 # Scale up at 70% CPU
scaleMetric: cpu
model:
modelFormat:
name: mlflow
storageUri: "s3://ml-models/churn-predictor/v2.3"
resources:
requests:
cpu: "1"
memory: "2Gi"
limits:
cpu: "2"
memory: "4Gi"
# Canary traffic split
canaryTrafficPercent: 10
transformer:
containers:
- name: feature-transformer
image: registry.company.com/ml/feature-transformer:v1.2
resources:
requests:
cpu: "500m"
            memory: "512Mi"
KServe with Custom Predictor
import kserve
from kserve import Model, ModelServer
import joblib
import numpy as np
class ChurnPredictor(Model):
def __init__(self, name: str):
super().__init__(name)
self.model = None
def load(self):
self.model = joblib.load("/mnt/models/model.joblib")
self.ready = True
def predict(self, payload: dict, headers: dict = None) -> dict:
instances = payload.get("instances", [])
features = np.array(instances)
predictions = self.model.predict(features)
probabilities = self.model.predict_proba(features)
return {
"predictions": predictions.tolist(),
"probabilities": probabilities.tolist(),
}
if __name__ == "__main__":
model = ChurnPredictor("churn-predictor")
model.load()
    ModelServer().start([model])
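Once the InferenceService is up, KServe serves this predictor behind its v1 REST protocol (`POST /v1/models/<name>:predict` with an `instances` array, which is exactly what `ChurnPredictor.predict` reads). A hedged client sketch; the hostname and feature ordering are placeholders for your ingress and model schema:
# Illustrative call against the KServe v1 protocol; URL and feature order are assumptions.
import requests
url = "http://churn-predictor.ml-serving.example.com/v1/models/churn-predictor:predict"
payload = {"instances": [[14.0, 72.5, 3.0, 1.0]]}  # one feature vector per instance
resp = requests.post(url, json=payload, timeout=5)
resp.raise_for_status()
result = resp.json()
print(result["predictions"], result["probabilities"])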
Pattern 3: Batch Inference Pipeline
For use cases where real-time isn't required (recommendations, daily scoring):
from datetime import datetime
import pandas as pd
import mlflow.pyfunc
class BatchInferencePipeline:
"""Score millions of records efficiently in batch."""
def __init__(self, model_name: str, model_stage: str = "Production"):
self.model = mlflow.pyfunc.load_model(f"models:/{model_name}/{model_stage}")
self.model_name = model_name
def score_batch(self, data: pd.DataFrame, batch_size: int = 10000) -> pd.DataFrame:
"""Score data in chunks to manage memory."""
results = []
total_rows = len(data)
for i in range(0, total_rows, batch_size):
chunk = data.iloc[i:i + batch_size]
predictions = self.model.predict(chunk)
chunk_result = chunk.copy()
chunk_result["prediction"] = predictions
chunk_result["scored_at"] = datetime.utcnow().isoformat()
chunk_result["model_name"] = self.model_name
results.append(chunk_result)
progress = min(i + batch_size, total_rows) / total_rows * 100
print(f"Progress: {progress:.1f}% ({min(i + batch_size, total_rows)}/{total_rows})")
return pd.concat(results, ignore_index=True)
# Usage
pipeline = BatchInferencePipeline("churn-predictor")
customers = pd.read_parquet("s3://data/customers/latest.parquet")
scored = pipeline.score_batch(customers)
scored.to_parquet(f"s3://predictions/churn/{datetime.utcnow().strftime('%Y-%m-%d')}.parquet")
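In practice a batch job like this runs on a schedule rather than by hand. A minimal Airflow sketch of that wiring (Airflow is an assumption here, and the import path for BatchInferencePipeline is hypothetical):
# Sketch of a daily scoring DAG; assumes Airflow 2.x and that BatchInferencePipeline
# (defined above) is importable from your codebase.
from datetime import datetime
import pandas as pd
from airflow import DAG
from airflow.operators.python import PythonOperator
from my_ml_package.batch import BatchInferencePipeline  # hypothetical module path
def score_customers():
    pipeline = BatchInferencePipeline("churn-predictor")
    customers = pd.read_parquet("s3://data/customers/latest.parquet")
    scored = pipeline.score_batch(customers)
    scored.to_parquet(f"s3://predictions/churn/{datetime.utcnow().strftime('%Y-%m-%d')}.parquet")
with DAG(
    dag_id="daily_churn_scoring",
    start_date=datetime(2024, 1, 1),
    schedule="0 5 * * *",  # daily at 05:00 UTC
    catchup=False,
) as dag:
    PythonOperator(task_id="score_customers", python_callable=score_customers)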
A/B Testing Model Versions
import hashlib
from typing import Literal
class ABRouter:
"""Deterministic A/B routing for model experiments."""
def __init__(self, experiments: dict[str, dict]):
"""
experiments = {
"control": {"model": model_a, "traffic": 0.80},
"treatment": {"model": model_b, "traffic": 0.20},
}
"""
self.experiments = experiments
self._validate_traffic()
def _validate_traffic(self):
total = sum(e["traffic"] for e in self.experiments.values())
assert abs(total - 1.0) < 0.01, f"Traffic must sum to 1.0, got {total}"
def route(self, entity_id: str) -> tuple[str, object]:
"""Deterministically assign entity to experiment variant."""
# Hash entity ID for consistent assignment
hash_val = int(hashlib.sha256(entity_id.encode()).hexdigest(), 16) % 10000
normalized = hash_val / 10000
cumulative = 0.0
for variant_name, config in self.experiments.items():
cumulative += config["traffic"]
if normalized < cumulative:
return variant_name, config["model"]
# Fallback to last variant
last_key = list(self.experiments.keys())[-1]
        return last_key, self.experiments[last_key]["model"]
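A sketch of how the router could sit in front of the serving layer; the second model entry and the entity ID are placeholders:
# Hypothetical wiring: 80/20 split between the current model and a candidate.
router = ABRouter({
    "control": {"model": models["churn"], "traffic": 0.80},
    "treatment": {"model": models["churn_v2"], "traffic": 0.20},  # assumes a second model in the cache
})
variant, model = router.route(entity_id="customer-12345")
# Log `variant` next to each prediction so offline analysis can compare the two arms.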
Model Optimization for Serving
Quantization and ONNX Export
import torch
import onnx
import onnxruntime as ort
import numpy as np
def optimize_pytorch_model(model, sample_input, output_path: str):
"""Export PyTorch model to ONNX for faster inference."""
model.eval()
# Export to ONNX
torch.onnx.export(
model,
sample_input,
output_path,
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
# Validate
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
return output_path
def serve_onnx(model_path: str, features: np.ndarray) -> np.ndarray:
"""Run inference with ONNX Runtime — typically 2-5x faster than PyTorch."""
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
result = session.run(None, {input_name: features.astype(np.float32)})
    return result[0]
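The heading above also promises quantization. For CPU-bound serving, dynamic INT8 quantization of the exported ONNX graph usually shrinks the model and can speed inference up further; gains vary by architecture, so validate accuracy on held-out data before shipping. A sketch using onnxruntime's quantization utilities (paths are illustrative):
from onnxruntime.quantization import QuantType, quantize_dynamic
def quantize_onnx_model(model_path: str, output_path: str) -> str:
    """Dynamically quantize ONNX weights to INT8 for smaller, faster CPU inference."""
    quantize_dynamic(model_input=model_path, model_output=output_path, weight_type=QuantType.QInt8)
    return output_path
# Usage (illustrative): quantize_onnx_model("churn.onnx", "churn.int8.onnx")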
Serving Infrastructure Comparison
| Framework | Protocol | Auto-scaling | GPU | Multi-model | Best For |
|-----------|----------|--------------|-----|-------------|----------|
| FastAPI | REST | Manual/K8s HPA | Manual | Manual | Simple, flexible |
| KServe | REST/gRPC | Built-in | Native | Built-in | K8s production |
| Seldon Core | REST/gRPC | Built-in | Native | Built-in | Complex graphs |
| BentoML | REST/gRPC | BentoCloud | Supported | Built-in | Easy packaging |
| TF Serving | REST/gRPC | Manual | Native | Built-in | TensorFlow only |
| Triton | REST/gRPC | Built-in | Optimized | Built-in | GPU inference |
Monitoring Serving Infrastructure
Key metrics to track in production:
# Prometheus alerts for model serving
groups:
- name: ml-serving
rules:
- alert: HighPredictionLatency
        expr: histogram_quantile(0.99, sum by (le, model_name) (rate(prediction_latency_seconds_bucket[5m]))) > 0.2
for: 5m
labels:
severity: warning
annotations:
summary: "P99 prediction latency exceeds 200ms"
- alert: HighErrorRate
expr: rate(prediction_total{status="error"}[5m]) / rate(prediction_total[5m]) > 0.01
for: 2m
labels:
severity: critical
annotations:
summary: "Model error rate exceeds 1%"
- alert: LowThroughput
expr: rate(prediction_total[5m]) < 1
for: 10m
labels:
severity: warning
annotations:
          summary: "Prediction throughput dropped below 1 req/s"
Related Resources
- MLOps overview — Where model serving fits in the MLOps lifecycle
- ML CI/CD — Automate model deployment
- Model monitoring — Monitor served models for degradation
- Kubeflow Pipelines — Orchestrate training that feeds your serving layer
Need production model serving? DeviDevs builds scalable ML serving infrastructure with auto-scaling, A/B testing, and monitoring. Get a free assessment →