AI Model Security Deployment Checklist: From Development to Production
Deploying AI models to production introduces security considerations beyond traditional software deployments. Models can be stolen, manipulated, or abused in ways that traditional security controls don't address.
This checklist provides a comprehensive framework for secure AI model deployment.
Pre-Deployment Security Assessment
Model Security Review
## Model Security Checklist
### Training Data Security
- [ ] Training data sources documented and validated
- [ ] No unauthorized PII in training data
- [ ] Data poisoning risks assessed and mitigated
- [ ] Training data access logs maintained
- [ ] Sensitive data properly anonymized or removed
### Model Integrity
- [ ] Model weights checksummed and verified
- [ ] Model provenance documented
- [ ] No backdoors introduced during training
- [ ] Model behavior validated against expected outputs
- [ ] Adversarial robustness testing completed
### Bias and Fairness
- [ ] Bias assessment completed across protected groups
- [ ] Fairness metrics documented
- [ ] Mitigation strategies implemented where needed
- [ ] Ongoing bias monitoring plan in place
### Intellectual Property
- [ ] Model licensing terms understood
- [ ] Third-party model usage properly licensed
- [ ] Proprietary model protection measures implemented
- [ ] Model theft detection capabilities plannedSecurity Architecture Review
class SecurityArchitectureReview:
"""Framework for reviewing AI deployment security architecture."""
def __init__(self):
self.checklist_items = {
'network_security': [
('api_gateway', 'API endpoints protected by gateway'),
('tls_encryption', 'TLS 1.3 for all communications'),
('network_segmentation', 'Model serving isolated from other systems'),
('ddos_protection', 'DDoS mitigation in place'),
('ip_allowlisting', 'IP allowlisting for sensitive endpoints'),
],
'authentication': [
('api_keys', 'API key authentication implemented'),
('oauth2', 'OAuth2 for user-facing applications'),
('key_rotation', 'Automatic key rotation configured'),
('mfa', 'MFA for administrative access'),
],
'authorization': [
('rbac', 'Role-based access control implemented'),
('least_privilege', 'Minimum necessary permissions granted'),
('api_scopes', 'API scopes defined and enforced'),
('resource_policies', 'Resource-level policies in place'),
],
'data_protection': [
('encryption_at_rest', 'Model weights encrypted at rest'),
('encryption_in_transit', 'All data encrypted in transit'),
('input_sanitization', 'Input sanitization implemented'),
('output_filtering', 'Output filtering in place'),
('pii_handling', 'PII handling procedures defined'),
],
'monitoring': [
('access_logging', 'All API access logged'),
('anomaly_detection', 'Anomaly detection configured'),
('performance_monitoring', 'Performance metrics tracked'),
('security_alerting', 'Security alerts configured'),
],
}
def generate_review_report(self, assessment: dict) -> dict:
"""Generate security architecture review report."""
report = {
'summary': {},
'findings': [],
'recommendations': []
}
for category, items in self.checklist_items.items():
category_score = 0
for item_id, description in items:
status = assessment.get(category, {}).get(item_id, 'not_assessed')
if status == 'implemented':
category_score += 1
elif status == 'partial':
category_score += 0.5
report['findings'].append({
'category': category,
'item': item_id,
'status': 'partial',
'description': description
})
elif status == 'not_implemented':
report['findings'].append({
'category': category,
'item': item_id,
'status': 'missing',
'description': description,
'risk': self._assess_risk(category, item_id)
})
report['summary'][category] = {
'score': category_score / len(items),
'items_total': len(items),
'items_implemented': category_score
}
return reportInfrastructure Security
Container and Orchestration Security
# Kubernetes security configurations for AI model serving
apiVersion: v1
kind: Pod
metadata:
name: model-serving
annotations:
seccomp.security.alpha.kubernetes.io/pod: runtime/default
spec:
securityContext:
runAsNonRoot: true
runAsUser: 1000
fsGroup: 1000
containers:
- name: model-server
image: your-registry/model-server:v1.0.0
imagePullPolicy: Always
securityContext:
allowPrivilegeEscalation: false
readOnlyRootFilesystem: true
capabilities:
drop:
- ALL
resources:
limits:
memory: "8Gi"
cpu: "4"
nvidia.com/gpu: "1"
requests:
memory: "4Gi"
cpu: "2"
volumeMounts:
- name: model-weights
mountPath: /models
readOnly: true
- name: tmp
mountPath: /tmp
livenessProbe:
httpGet:
path: /health
port: 8080
initialDelaySeconds: 30
periodSeconds: 10
env:
- name: MODEL_PATH
value: /models/production
- name: API_KEY
valueFrom:
secretKeyRef:
name: model-api-secrets
key: api-key
volumes:
- name: model-weights
persistentVolumeClaim:
claimName: model-weights-pvc
- name: tmp
emptyDir: {}
---
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
name: model-serving-policy
spec:
podSelector:
matchLabels:
app: model-serving
policyTypes:
- Ingress
- Egress
ingress:
- from:
- podSelector:
matchLabels:
app: api-gateway
ports:
- protocol: TCP
port: 8080
egress:
- to:
- podSelector:
matchLabels:
app: logging
ports:
- protocol: TCP
port: 443Model Weight Protection
import hashlib
from cryptography.fernet import Fernet
from pathlib import Path
class ModelWeightProtection:
"""Protect model weights at rest and in transit."""
def __init__(self, encryption_key: bytes):
self.cipher = Fernet(encryption_key)
self.integrity_db = {}
def encrypt_model(self, model_path: Path, output_path: Path) -> dict:
"""Encrypt model weights for secure storage."""
# Read model file
with open(model_path, 'rb') as f:
model_data = f.read()
# Calculate integrity hash before encryption
original_hash = hashlib.sha256(model_data).hexdigest()
# Encrypt
encrypted_data = self.cipher.encrypt(model_data)
# Write encrypted model
with open(output_path, 'wb') as f:
f.write(encrypted_data)
# Store integrity information
self.integrity_db[str(output_path)] = {
'original_hash': original_hash,
'encrypted_hash': hashlib.sha256(encrypted_data).hexdigest(),
'encrypted_at': datetime.utcnow().isoformat()
}
return {
'original_path': str(model_path),
'encrypted_path': str(output_path),
'original_hash': original_hash,
'original_size': len(model_data),
'encrypted_size': len(encrypted_data)
}
def decrypt_and_verify(self, encrypted_path: Path) -> bytes:
"""Decrypt model weights and verify integrity."""
# Read encrypted model
with open(encrypted_path, 'rb') as f:
encrypted_data = f.read()
# Verify encrypted file integrity
expected_encrypted_hash = self.integrity_db.get(
str(encrypted_path), {}
).get('encrypted_hash')
if expected_encrypted_hash:
actual_hash = hashlib.sha256(encrypted_data).hexdigest()
if actual_hash != expected_encrypted_hash:
raise IntegrityError("Encrypted model file has been tampered with")
# Decrypt
model_data = self.cipher.decrypt(encrypted_data)
# Verify decrypted content integrity
expected_original_hash = self.integrity_db.get(
str(encrypted_path), {}
).get('original_hash')
if expected_original_hash:
actual_hash = hashlib.sha256(model_data).hexdigest()
if actual_hash != expected_original_hash:
raise IntegrityError("Decrypted model doesn't match original")
return model_data
def verify_integrity(self, model_path: Path) -> bool:
"""Verify model file integrity without decryption."""
with open(model_path, 'rb') as f:
data = f.read()
actual_hash = hashlib.sha256(data).hexdigest()
expected_hash = self.integrity_db.get(str(model_path), {}).get('encrypted_hash')
return actual_hash == expected_hashAPI Security
Authentication and Authorization
from fastapi import FastAPI, Security, HTTPException, Depends
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
import jwt
app = FastAPI()
class APISecurityManager:
"""Manage API security for model serving endpoints."""
def __init__(self, config: dict):
self.api_key_header = APIKeyHeader(name="X-API-Key")
self.oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
self.jwt_secret = config['jwt_secret']
self.api_keys = self._load_api_keys()
self.rate_limiter = RateLimiter(config['rate_limits'])
async def verify_api_key(self, api_key: str = Security(APIKeyHeader(name="X-API-Key"))) -> dict:
"""Verify API key and return associated permissions."""
if api_key not in self.api_keys:
raise HTTPException(status_code=401, detail="Invalid API key")
key_info = self.api_keys[api_key]
# Check if key is active
if not key_info.get('active', True):
raise HTTPException(status_code=401, detail="API key is disabled")
# Check expiration
if key_info.get('expires_at'):
if datetime.utcnow() > key_info['expires_at']:
raise HTTPException(status_code=401, detail="API key has expired")
return {
'client_id': key_info['client_id'],
'scopes': key_info.get('scopes', []),
'rate_limit_tier': key_info.get('rate_limit_tier', 'default')
}
async def verify_jwt(self, token: str = Depends(OAuth2PasswordBearer(tokenUrl="token"))) -> dict:
"""Verify JWT token and extract claims."""
try:
payload = jwt.decode(
token,
self.jwt_secret,
algorithms=['HS256'],
options={'require': ['exp', 'sub', 'scopes']}
)
return {
'user_id': payload['sub'],
'scopes': payload['scopes'],
'exp': payload['exp']
}
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError as e:
raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}")
def require_scope(self, required_scope: str):
"""Decorator to require specific scope for endpoint."""
def decorator(func):
async def wrapper(*args, auth_info: dict = None, **kwargs):
if required_scope not in auth_info.get('scopes', []):
raise HTTPException(
status_code=403,
detail=f"Missing required scope: {required_scope}"
)
return await func(*args, **kwargs)
return wrapper
return decorator
# Usage in FastAPI
security_manager = APISecurityManager(config)
@app.post("/v1/inference")
async def run_inference(
request: InferenceRequest,
auth: dict = Depends(security_manager.verify_api_key)
):
"""Run model inference with API key authentication."""
# Check rate limits
if not security_manager.rate_limiter.check_allowed(
auth['client_id'],
auth['rate_limit_tier']
):
raise HTTPException(status_code=429, detail="Rate limit exceeded")
# Process inference
result = await model.predict(request.input)
return {"result": result}Input Validation
from pydantic import BaseModel, validator, Field
from typing import List, Optional
import re
class InferenceRequest(BaseModel):
"""Validated inference request model."""
input_text: str = Field(..., min_length=1, max_length=10000)
parameters: Optional[dict] = Field(default_factory=dict)
session_id: Optional[str] = Field(None, regex=r'^[a-zA-Z0-9-]{36}$')
@validator('input_text')
def validate_input_text(cls, v):
# Check for null bytes
if '\x00' in v:
raise ValueError('Input contains null bytes')
# Check for excessive special characters
special_char_ratio = len(re.findall(r'[^\w\s]', v)) / len(v) if v else 0
if special_char_ratio > 0.5:
raise ValueError('Input contains too many special characters')
return v
@validator('parameters')
def validate_parameters(cls, v):
allowed_params = {'temperature', 'max_tokens', 'top_p', 'top_k'}
for key in v.keys():
if key not in allowed_params:
raise ValueError(f'Unknown parameter: {key}')
# Validate parameter ranges
if 'temperature' in v:
if not 0 <= v['temperature'] <= 2:
raise ValueError('Temperature must be between 0 and 2')
if 'max_tokens' in v:
if not 1 <= v['max_tokens'] <= 4096:
raise ValueError('max_tokens must be between 1 and 4096')
return v
class InputSanitizer:
"""Sanitize and validate model inputs."""
def __init__(self, config: dict):
self.max_length = config.get('max_input_length', 10000)
self.injection_patterns = self._load_injection_patterns()
def sanitize(self, input_text: str) -> str:
"""Sanitize input text for safe model consumption."""
# Truncate if too long
if len(input_text) > self.max_length:
input_text = input_text[:self.max_length]
# Remove null bytes and control characters
input_text = input_text.replace('\x00', '')
input_text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', input_text)
# Normalize unicode
input_text = unicodedata.normalize('NFKC', input_text)
return input_text
def check_injection(self, input_text: str) -> dict:
"""Check for prompt injection attempts."""
detections = []
for pattern_name, pattern in self.injection_patterns.items():
if re.search(pattern, input_text, re.IGNORECASE):
detections.append(pattern_name)
return {
'is_suspicious': len(detections) > 0,
'patterns_detected': detections
}Monitoring and Logging
Security Monitoring
import structlog
from dataclasses import dataclass
from typing import Optional
@dataclass
class SecurityEvent:
event_type: str
severity: str
user_id: Optional[str]
client_id: Optional[str]
details: dict
timestamp: str
class SecurityMonitor:
"""Monitor AI model serving for security events."""
def __init__(self, config: dict):
self.logger = structlog.get_logger()
self.alert_thresholds = config.get('alert_thresholds', {})
self.metrics_collector = MetricsCollector()
def log_inference_request(self, request: dict, response: dict,
auth_info: dict, duration_ms: float):
"""Log inference request with security-relevant details."""
# Extract security-relevant metrics
input_length = len(request.get('input_text', ''))
output_length = len(str(response.get('result', '')))
self.logger.info(
"inference_request",
client_id=auth_info.get('client_id'),
input_length=input_length,
output_length=output_length,
duration_ms=duration_ms,
parameters=request.get('parameters', {}),
# Don't log actual input/output for privacy
)
# Collect metrics
self.metrics_collector.record({
'inference_count': 1,
'inference_latency_ms': duration_ms,
'input_tokens': input_length // 4, # Rough estimate
'output_tokens': output_length // 4,
}, labels={'client_id': auth_info.get('client_id')})
def log_security_event(self, event: SecurityEvent):
"""Log security-relevant events."""
log_method = {
'critical': self.logger.critical,
'high': self.logger.error,
'medium': self.logger.warning,
'low': self.logger.info,
}.get(event.severity, self.logger.info)
log_method(
"security_event",
event_type=event.event_type,
severity=event.severity,
user_id=event.user_id,
client_id=event.client_id,
details=event.details,
timestamp=event.timestamp
)
# Check if alert needed
if event.severity in ['critical', 'high']:
self._send_alert(event)
def detect_anomalies(self, client_id: str, metrics: dict) -> list:
"""Detect anomalous behavior patterns."""
anomalies = []
# Check request rate anomaly
if metrics.get('requests_per_minute', 0) > self.alert_thresholds.get('rpm', 100):
anomalies.append({
'type': 'high_request_rate',
'value': metrics['requests_per_minute'],
'threshold': self.alert_thresholds['rpm']
})
# Check for unusual input patterns
if metrics.get('avg_input_length', 0) > self.alert_thresholds.get('input_length', 5000):
anomalies.append({
'type': 'large_inputs',
'value': metrics['avg_input_length'],
'threshold': self.alert_thresholds['input_length']
})
# Check for high error rate
error_rate = metrics.get('error_count', 0) / max(metrics.get('request_count', 1), 1)
if error_rate > self.alert_thresholds.get('error_rate', 0.1):
anomalies.append({
'type': 'high_error_rate',
'value': error_rate,
'threshold': self.alert_thresholds['error_rate']
})
return anomalies
def _send_alert(self, event: SecurityEvent):
"""Send security alert through configured channels."""
# Implementation depends on alerting infrastructure
passAudit Logging
class AuditLogger:
"""Comprehensive audit logging for AI model serving."""
def __init__(self, config: dict):
self.storage = config.get('storage', 'elasticsearch')
self.retention_days = config.get('retention_days', 90)
self.logger = self._init_logger()
def log_access(self, event_type: str, details: dict):
"""Log access events."""
audit_entry = {
'timestamp': datetime.utcnow().isoformat(),
'event_type': event_type,
'details': details,
'log_type': 'access'
}
# Add immutability hash
audit_entry['integrity_hash'] = self._compute_hash(audit_entry)
self.logger.info(audit_entry)
def log_admin_action(self, action: str, user: str, details: dict):
"""Log administrative actions."""
audit_entry = {
'timestamp': datetime.utcnow().isoformat(),
'action': action,
'admin_user': user,
'details': details,
'log_type': 'admin',
'integrity_hash': None
}
audit_entry['integrity_hash'] = self._compute_hash(audit_entry)
self.logger.info(audit_entry)
def log_model_change(self, change_type: str, model_info: dict,
user: str, reason: str):
"""Log model deployment and configuration changes."""
audit_entry = {
'timestamp': datetime.utcnow().isoformat(),
'change_type': change_type,
'model_id': model_info.get('model_id'),
'model_version': model_info.get('version'),
'changed_by': user,
'reason': reason,
'previous_state': model_info.get('previous_state'),
'new_state': model_info.get('new_state'),
'log_type': 'model_change'
}
audit_entry['integrity_hash'] = self._compute_hash(audit_entry)
self.logger.info(audit_entry)
def _compute_hash(self, entry: dict) -> str:
"""Compute integrity hash for audit entry."""
entry_copy = entry.copy()
entry_copy.pop('integrity_hash', None)
content = json.dumps(entry_copy, sort_keys=True)
return hashlib.sha256(content.encode()).hexdigest()Deployment Checklist Summary
# AI Model Deployment Security Checklist
## Pre-Deployment
- [ ] Model security review completed
- [ ] Training data security validated
- [ ] Bias assessment completed
- [ ] Security architecture reviewed
- [ ] Penetration testing completed
## Infrastructure
- [ ] Network segmentation implemented
- [ ] TLS 1.3 configured
- [ ] Container security hardened
- [ ] Resource limits defined
- [ ] Model weights encrypted
## Access Control
- [ ] API authentication implemented
- [ ] RBAC configured
- [ ] API scopes defined
- [ ] Key rotation automated
- [ ] MFA for admin access
## Input/Output Security
- [ ] Input validation implemented
- [ ] Injection detection enabled
- [ ] Output filtering configured
- [ ] Rate limiting active
- [ ] PII handling defined
## Monitoring
- [ ] Access logging enabled
- [ ] Security alerting configured
- [ ] Anomaly detection active
- [ ] Audit trail complete
- [ ] Metrics collection operational
## Incident Response
- [ ] IR playbook documented
- [ ] Rollback procedures tested
- [ ] Contact list updated
- [ ] Communication templates ready
## Compliance
- [ ] Data privacy requirements met
- [ ] Industry regulations addressed
- [ ] Documentation complete
- [ ] Regular audits scheduledConclusion
Secure AI model deployment requires attention to multiple layers of security, from infrastructure hardening to API protection to comprehensive monitoring. This checklist provides a foundation, but should be customized based on your specific deployment environment and risk profile.
At DeviDevs, we help organizations deploy AI models securely with comprehensive security assessments and implementation support. Contact us to discuss your AI deployment security needs.