AI Model Versioning and Security: Secure MLOps Model Management
Proper model versioning is critical for secure AI deployments. This guide covers implementing secure model management with integrity verification, access controls, and comprehensive audit trails.
Secure Model Registry
Model Registry Implementation
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Any
from enum import Enum
import hashlib
import json
import hmac
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives import serialization
class ModelStatus(Enum):
DEVELOPMENT = "development"
STAGING = "staging"
PRODUCTION = "production"
DEPRECATED = "deprecated"
ARCHIVED = "archived"
class ModelType(Enum):
CLASSIFICATION = "classification"
REGRESSION = "regression"
NLP = "nlp"
COMPUTER_VISION = "computer_vision"
GENERATIVE = "generative"
RECOMMENDATION = "recommendation"
@dataclass
class ModelMetadata:
name: str
version: str
model_type: ModelType
framework: str # pytorch, tensorflow, sklearn, etc.
description: str
created_by: str
created_at: datetime
input_schema: Dict[str, Any]
output_schema: Dict[str, Any]
training_dataset: str
training_parameters: Dict[str, Any]
metrics: Dict[str, float]
tags: List[str] = field(default_factory=list)
dependencies: Dict[str, str] = field(default_factory=dict)
@dataclass
class ModelArtifact:
artifact_id: str
model_path: str
checksum: str
size_bytes: int
encrypted: bool = False
encryption_key_id: Optional[str] = None
signature: Optional[str] = None
signed_by: Optional[str] = None
@dataclass
class ModelVersion:
version_id: str
metadata: ModelMetadata
artifact: ModelArtifact
status: ModelStatus
promoted_at: Optional[datetime] = None
promoted_by: Optional[str] = None
parent_version: Optional[str] = None
lineage: List[str] = field(default_factory=list)
class SecureModelRegistry:
def __init__(self, storage_backend, encryption_key: bytes):
self.storage = storage_backend
self.cipher = Fernet(encryption_key)
self.models: Dict[str, Dict[str, ModelVersion]] = {}
self.signing_keys: Dict[str, rsa.RSAPrivateKey] = {}
def register_model(self,
metadata: ModelMetadata,
model_bytes: bytes,
user_id: str,
sign: bool = True) -> ModelVersion:
"""Register a new model version with security controls."""
# Generate version ID
version_id = self._generate_version_id(metadata.name, metadata.version)
# Calculate checksum for integrity
checksum = self._calculate_checksum(model_bytes)
# Encrypt model artifact
encrypted_bytes = self.cipher.encrypt(model_bytes)
encryption_key_id = self._get_current_key_id()
# Store encrypted model
model_path = self._store_model(
metadata.name,
metadata.version,
encrypted_bytes
)
# Create artifact record
artifact = ModelArtifact(
artifact_id=f"artifact_{version_id}",
model_path=model_path,
checksum=checksum,
size_bytes=len(model_bytes),
encrypted=True,
encryption_key_id=encryption_key_id
)
# Sign artifact if requested
if sign:
signature = self._sign_artifact(artifact, user_id)
artifact.signature = signature
artifact.signed_by = user_id
# Create version record
version = ModelVersion(
version_id=version_id,
metadata=metadata,
artifact=artifact,
status=ModelStatus.DEVELOPMENT,
lineage=self._build_lineage(metadata.name, metadata.version)
)
# Store version
if metadata.name not in self.models:
self.models[metadata.name] = {}
self.models[metadata.name][metadata.version] = version
# Log registration
self._audit_log(
action="MODEL_REGISTERED",
model_name=metadata.name,
version=metadata.version,
user_id=user_id,
details={
"checksum": checksum,
"size_bytes": len(model_bytes),
"signed": sign
}
)
return version
def get_model(self,
name: str,
version: str,
user_id: str,
verify_integrity: bool = True) -> tuple:
"""Retrieve model with integrity verification."""
model_versions = self.models.get(name, {})
model_version = model_versions.get(version)
if not model_version:
raise ValueError(f"Model {name}:{version} not found")
# Check access permissions
if not self._check_access(name, version, user_id, "read"):
raise PermissionError(f"User {user_id} not authorized to access {name}:{version}")
# Retrieve encrypted model
encrypted_bytes = self._retrieve_model(model_version.artifact.model_path)
# Decrypt model
model_bytes = self.cipher.decrypt(encrypted_bytes)
# Verify integrity
if verify_integrity:
current_checksum = self._calculate_checksum(model_bytes)
if current_checksum != model_version.artifact.checksum:
self._audit_log(
action="INTEGRITY_VIOLATION",
model_name=name,
version=version,
user_id=user_id,
details={
"expected_checksum": model_version.artifact.checksum,
"actual_checksum": current_checksum
}
)
raise SecurityError(f"Model integrity check failed for {name}:{version}")
# Verify signature if present
if model_version.artifact.signature:
if not self._verify_signature(model_version.artifact):
raise SecurityError(f"Model signature verification failed for {name}:{version}")
# Log access
self._audit_log(
action="MODEL_ACCESSED",
model_name=name,
version=version,
user_id=user_id,
details={"integrity_verified": verify_integrity}
)
return model_bytes, model_version.metadata
def promote_model(self,
name: str,
version: str,
target_status: ModelStatus,
user_id: str,
approval_ticket: Optional[str] = None) -> ModelVersion:
"""Promote model to a new status with approval workflow."""
model_version = self.models.get(name, {}).get(version)
if not model_version:
raise ValueError(f"Model {name}:{version} not found")
# Check promotion permissions
if not self._check_access(name, version, user_id, "promote"):
raise PermissionError(f"User {user_id} not authorized to promote {name}:{version}")
# Validate promotion path
valid_promotions = {
ModelStatus.DEVELOPMENT: [ModelStatus.STAGING],
ModelStatus.STAGING: [ModelStatus.PRODUCTION, ModelStatus.DEVELOPMENT],
ModelStatus.PRODUCTION: [ModelStatus.DEPRECATED],
ModelStatus.DEPRECATED: [ModelStatus.ARCHIVED]
}
if target_status not in valid_promotions.get(model_version.status, []):
raise ValueError(
f"Cannot promote from {model_version.status.value} to {target_status.value}"
)
# Production promotion requires approval
if target_status == ModelStatus.PRODUCTION:
if not approval_ticket:
raise ValueError("Production promotion requires approval ticket")
if not self._verify_approval(approval_ticket, name, version):
raise PermissionError("Approval not found or invalid")
# Update status
old_status = model_version.status
model_version.status = target_status
model_version.promoted_at = datetime.utcnow()
model_version.promoted_by = user_id
# Log promotion
self._audit_log(
action="MODEL_PROMOTED",
model_name=name,
version=version,
user_id=user_id,
details={
"from_status": old_status.value,
"to_status": target_status.value,
"approval_ticket": approval_ticket
}
)
return model_version
def rollback_model(self,
name: str,
target_version: str,
user_id: str,
reason: str) -> ModelVersion:
"""Rollback to a previous model version."""
# Get current production version
current_prod = self._get_production_version(name)
if not current_prod:
raise ValueError(f"No production version found for {name}")
target = self.models.get(name, {}).get(target_version)
if not target:
raise ValueError(f"Target version {name}:{target_version} not found")
# Check rollback permissions
if not self._check_access(name, target_version, user_id, "rollback"):
raise PermissionError(f"User {user_id} not authorized to rollback {name}")
# Verify target version was previously in production
if target_version not in current_prod.lineage:
raise ValueError(f"Version {target_version} was never in production")
# Demote current production
current_prod.status = ModelStatus.DEPRECATED
current_prod.promoted_at = datetime.utcnow()
current_prod.promoted_by = user_id
# Restore target to production
target.status = ModelStatus.PRODUCTION
target.promoted_at = datetime.utcnow()
target.promoted_by = user_id
# Log rollback
self._audit_log(
action="MODEL_ROLLBACK",
model_name=name,
version=target_version,
user_id=user_id,
details={
"rolled_back_from": current_prod.version_id,
"reason": reason
}
)
return target
def _calculate_checksum(self, data: bytes) -> str:
"""Calculate SHA-256 checksum of data."""
return hashlib.sha256(data).hexdigest()
def _sign_artifact(self, artifact: ModelArtifact, user_id: str) -> str:
"""Sign artifact with user's private key."""
if user_id not in self.signing_keys:
raise ValueError(f"No signing key found for user {user_id}")
private_key = self.signing_keys[user_id]
# Create signable content
content = f"{artifact.artifact_id}:{artifact.checksum}:{artifact.size_bytes}"
# Sign with RSA
signature = private_key.sign(
content.encode(),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return base64.b64encode(signature).decode()
def _verify_signature(self, artifact: ModelArtifact) -> bool:
"""Verify artifact signature."""
# Implementation would verify against stored public key
return True
def _generate_version_id(self, name: str, version: str) -> str:
"""Generate unique version identifier."""
timestamp = datetime.utcnow().strftime("%Y%m%d%H%M%S")
return f"{name}_{version}_{timestamp}"
def _get_current_key_id(self) -> str:
"""Get current encryption key ID."""
return "key_v1"
def _store_model(self, name: str, version: str, data: bytes) -> str:
"""Store model in backend storage."""
path = f"models/{name}/{version}/model.bin"
self.storage.write(path, data)
return path
def _retrieve_model(self, path: str) -> bytes:
"""Retrieve model from backend storage."""
return self.storage.read(path)
def _build_lineage(self, name: str, version: str) -> List[str]:
"""Build version lineage."""
lineage = []
model_versions = self.models.get(name, {})
# Find production versions in lineage
for v_id, v in model_versions.items():
if v.status == ModelStatus.PRODUCTION:
lineage.append(v_id)
return lineage
def _get_production_version(self, name: str) -> Optional[ModelVersion]:
"""Get current production version."""
for version in self.models.get(name, {}).values():
if version.status == ModelStatus.PRODUCTION:
return version
return None
def _check_access(self, name: str, version: str, user_id: str, action: str) -> bool:
"""Check if user has access to perform action."""
# Implementation would check RBAC permissions
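# For example, this could delegate to ModelAccessControl.check_permission (see the Model Access Control section below)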
return True
def _verify_approval(self, ticket: str, name: str, version: str) -> bool:
"""Verify approval ticket for production promotion."""
# Implementation would verify against approval system
return True
def _audit_log(self, action: str, model_name: str, version: str,
user_id: str, details: Dict):
"""Log audit event."""
event = {
"timestamp": datetime.utcnow().isoformat(),
"action": action,
"model_name": model_name,
"version": version,
"user_id": user_id,
"details": details
}
# Store audit event
print(f"AUDIT: {json.dumps(event)}")
class SecurityError(Exception):
"""Security-related error."""
pass
Model Integrity Verification
import hashlib
import json
from datetime import datetime
from typing import Dict, List, Tuple
class ModelIntegrityVerifier:
def __init__(self):
self.verification_chain: List[Dict] = []
def create_manifest(self, model_path: str, metadata: Dict) -> Dict:
"""Create integrity manifest for model."""
manifest = {
"manifest_version": "1.0",
"created_at": datetime.utcnow().isoformat(),
"model_info": {
"name": metadata.get("name"),
"version": metadata.get("version"),
"framework": metadata.get("framework")
},
"checksums": {},
"dependencies": [],
"chain_hash": None
}
# Calculate checksums for all model files
manifest["checksums"] = self._calculate_file_checksums(model_path)
# Include dependency checksums
manifest["dependencies"] = self._get_dependency_checksums(
metadata.get("dependencies", {})
)
# Calculate chain hash (merkle root of all checksums)
manifest["chain_hash"] = self._calculate_chain_hash(manifest["checksums"])
return manifest
def _calculate_file_checksums(self, model_path: str) -> Dict[str, str]:
"""Calculate checksums for all files in model."""
import os
checksums = {}
for root, dirs, files in os.walk(model_path):
for file in files:
file_path = os.path.join(root, file)
relative_path = os.path.relpath(file_path, model_path)
with open(file_path, 'rb') as f:
content = f.read()
checksums[relative_path] = {
"sha256": hashlib.sha256(content).hexdigest(),
"sha512": hashlib.sha512(content).hexdigest(),
"size": len(content)
}
return checksums
def _get_dependency_checksums(self, dependencies: Dict[str, str]) -> List[Dict]:
"""Get checksums for model dependencies."""
dep_checksums = []
for package, version in dependencies.items():
dep_checksums.append({
"package": package,
"version": version,
"checksum": self._get_package_checksum(package, version)
})
return dep_checksums
def _get_package_checksum(self, package: str, version: str) -> str:
"""Get checksum for a package version from PyPI or similar."""
# Would query package registry for official checksum
return f"sha256:{package}_{version}"
def _calculate_chain_hash(self, checksums: Dict) -> str:
"""Calculate merkle root of all checksums."""
# Sort checksums for deterministic ordering
sorted_values = [
checksums[k]["sha256"]
for k in sorted(checksums.keys())
]
# Build merkle tree
while len(sorted_values) > 1:
new_level = []
for i in range(0, len(sorted_values), 2):
left = sorted_values[i]
right = sorted_values[i + 1] if i + 1 < len(sorted_values) else left
combined = hashlib.sha256(f"{left}{right}".encode()).hexdigest()
new_level.append(combined)
sorted_values = new_level
return sorted_values[0] if sorted_values else ""
def verify_integrity(self, model_path: str, manifest: Dict) -> Tuple[bool, List[str]]:
"""Verify model integrity against manifest."""
errors = []
# Verify file checksums
current_checksums = self._calculate_file_checksums(model_path)
for file_path, expected in manifest["checksums"].items():
if file_path not in current_checksums:
errors.append(f"Missing file: {file_path}")
continue
actual = current_checksums[file_path]
if actual["sha256"] != expected["sha256"]:
errors.append(f"Checksum mismatch for {file_path}")
if actual["size"] != expected["size"]:
errors.append(f"Size mismatch for {file_path}")
# Check for unexpected files
for file_path in current_checksums:
if file_path not in manifest["checksums"]:
errors.append(f"Unexpected file: {file_path}")
# Verify chain hash
current_chain = self._calculate_chain_hash(current_checksums)
if current_chain != manifest["chain_hash"]:
errors.append("Chain hash mismatch - possible tampering")
return len(errors) == 0, errors
def verify_dependencies(self, manifest: Dict) -> Tuple[bool, List[str]]:
"""Verify installed dependencies match manifest."""
import pkg_resources  # note: pkg_resources is deprecated; importlib.metadata is the modern alternative
errors = []
for dep in manifest.get("dependencies", []):
package = dep["package"]
expected_version = dep["version"]
try:
installed = pkg_resources.get_distribution(package)
if installed.version != expected_version:
errors.append(
f"Dependency version mismatch: {package} "
f"(expected {expected_version}, got {installed.version})"
)
except pkg_resources.DistributionNotFound:
errors.append(f"Missing dependency: {package}")
return len(errors) == 0, errors
Model Access Control
from typing import Dict, List, Set, Tuple
from enum import Enum
class ModelPermission(Enum):
READ = "read"
WRITE = "write"
PROMOTE = "promote"
ROLLBACK = "rollback"
DELETE = "delete"
ADMIN = "admin"
class ModelAccessControl:
def __init__(self):
self.role_permissions: Dict[str, Set[ModelPermission]] = {
"viewer": {ModelPermission.READ},
"developer": {ModelPermission.READ, ModelPermission.WRITE},
"ml_engineer": {
ModelPermission.READ,
ModelPermission.WRITE,
ModelPermission.PROMOTE
},
"ml_admin": {
ModelPermission.READ,
ModelPermission.WRITE,
ModelPermission.PROMOTE,
ModelPermission.ROLLBACK,
ModelPermission.DELETE,
ModelPermission.ADMIN
}
}
self.user_roles: Dict[str, Dict[str, str]] = {} # user_id -> {model_name: role}
self.model_policies: Dict[str, Dict] = {}
def assign_role(self, user_id: str, model_name: str, role: str):
"""Assign a role to a user for a specific model."""
if role not in self.role_permissions:
raise ValueError(f"Unknown role: {role}")
if user_id not in self.user_roles:
self.user_roles[user_id] = {}
self.user_roles[user_id][model_name] = role
def check_permission(self,
user_id: str,
model_name: str,
permission: ModelPermission) -> bool:
"""Check if user has specific permission for a model."""
user_roles = self.user_roles.get(user_id, {})
# Check model-specific role
role = user_roles.get(model_name)
# Check wildcard role
if not role:
role = user_roles.get("*")
if not role:
return False
allowed_permissions = self.role_permissions.get(role, set())
return permission in allowed_permissions
def set_model_policy(self, model_name: str, policy: Dict):
"""Set security policy for a model."""
self.model_policies[model_name] = {
"require_approval_for_production": policy.get("require_approval", True),
"allowed_environments": policy.get("environments", ["dev", "staging", "prod"]),
"retention_days": policy.get("retention_days", 365),
"encryption_required": policy.get("encryption_required", True),
"signing_required": policy.get("signing_required", True),
"max_versions": policy.get("max_versions", 50)
}
def evaluate_policy(self, model_name: str, action: str, context: Dict) -> Tuple[bool, str]:
"""Evaluate policy for an action."""
policy = self.model_policies.get(model_name, {})
if action == "promote_to_production":
if policy.get("require_approval_for_production", True):
if not context.get("approval_ticket"):
return False, "Production promotion requires approval"
if action == "deploy":
allowed_envs = policy.get("allowed_environments", [])
target_env = context.get("environment")
if target_env not in allowed_envs:
return False, f"Deployment to {target_env} not allowed"
if action == "register":
if policy.get("encryption_required", True):
if not context.get("encrypted", False):
return False, "Model encryption is required"
if policy.get("signing_required", True):
if not context.get("signed", False):
return False, "Model signing is required"
return True, "Policy check passed"
def get_user_permissions(self, user_id: str, model_name: str) -> Set[ModelPermission]:
"""Get all permissions a user has for a model."""
user_roles = self.user_roles.get(user_id, {})
role = user_roles.get(model_name) or user_roles.get("*")
if not role:
return set()
return self.role_permissions.get(role, set())
Model Audit Trail
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import json
import hashlib
class ModelAuditTrail:
def __init__(self, storage_backend):
self.storage = storage_backend
self.events: List[Dict] = []
def log_event(self,
event_type: str,
model_name: str,
version: str,
user_id: str,
details: Dict,
ip_address: Optional[str] = None) -> str:
"""Log an audit event with tamper-evident chaining."""
timestamp = datetime.utcnow()
# Get previous event hash for chain
previous_hash = self._get_last_event_hash()
event = {
"event_id": self._generate_event_id(),
"timestamp": timestamp.isoformat(),
"event_type": event_type,
"model_name": model_name,
"version": version,
"user_id": user_id,
"ip_address": ip_address,
"details": details,
"previous_hash": previous_hash
}
# Calculate event hash for chain integrity
event["event_hash"] = self._calculate_event_hash(event)
# Store event
self.events.append(event)
self._persist_event(event)
return event["event_id"]
def query_events(self,
model_name: Optional[str] = None,
version: Optional[str] = None,
event_type: Optional[str] = None,
user_id: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = 100) -> List[Dict]:
"""Query audit events with filters."""
filtered = self.events
if model_name:
filtered = [e for e in filtered if e["model_name"] == model_name]
if version:
filtered = [e for e in filtered if e["version"] == version]
if event_type:
filtered = [e for e in filtered if e["event_type"] == event_type]
if user_id:
filtered = [e for e in filtered if e["user_id"] == user_id]
if start_time:
filtered = [e for e in filtered
if datetime.fromisoformat(e["timestamp"]) >= start_time]
if end_time:
filtered = [e for e in filtered
if datetime.fromisoformat(e["timestamp"]) <= end_time]
return filtered[:limit]
def verify_chain_integrity(self) -> Tuple[bool, List[str]]:
"""Verify the integrity of the audit chain."""
errors = []
for i, event in enumerate(self.events):
# Verify event hash
calculated_hash = self._calculate_event_hash({
k: v for k, v in event.items()
if k != "event_hash"
})
if calculated_hash != event["event_hash"]:
errors.append(f"Hash mismatch at event {event['event_id']}")
# Verify chain linkage
if i > 0:
if event["previous_hash"] != self.events[i-1]["event_hash"]:
errors.append(f"Chain break at event {event['event_id']}")
return len(errors) == 0, errors
def generate_compliance_report(self,
model_name: str,
start_date: datetime,
end_date: datetime) -> Dict:
"""Generate compliance report for a model."""
events = self.query_events(
model_name=model_name,
start_time=start_date,
end_time=end_date,
limit=10000
)
report = {
"model_name": model_name,
"report_period": {
"start": start_date.isoformat(),
"end": end_date.isoformat()
},
"generated_at": datetime.utcnow().isoformat(),
"summary": {
"total_events": len(events),
"events_by_type": {},
"unique_users": set(),
"versions_affected": set()
},
"events": events,
"chain_integrity": None
}
# Aggregate summary
for event in events:
event_type = event["event_type"]
report["summary"]["events_by_type"][event_type] = \
report["summary"]["events_by_type"].get(event_type, 0) + 1
report["summary"]["unique_users"].add(event["user_id"])
report["summary"]["versions_affected"].add(event["version"])
# Convert sets to lists for JSON
report["summary"]["unique_users"] = list(report["summary"]["unique_users"])
report["summary"]["versions_affected"] = list(report["summary"]["versions_affected"])
# Verify chain integrity
is_valid, errors = self.verify_chain_integrity()
report["chain_integrity"] = {
"valid": is_valid,
"errors": errors
}
return report
def _generate_event_id(self) -> str:
"""Generate unique event ID."""
return f"evt_{datetime.utcnow().strftime('%Y%m%d%H%M%S%f')}"
def _get_last_event_hash(self) -> str:
"""Get hash of the last event in the chain."""
if not self.events:
return "genesis"
return self.events[-1]["event_hash"]
def _calculate_event_hash(self, event: Dict) -> str:
"""Calculate SHA-256 hash of event."""
# Create deterministic string representation
event_str = json.dumps(event, sort_keys=True, default=str)
return hashlib.sha256(event_str.encode()).hexdigest()
def _persist_event(self, event: Dict):
"""Persist event to storage backend."""
# Would write to database or log system
pass
Best Practices
Model Security
- Encryption: Always encrypt models at rest and in transit
- Signing: Sign models to verify authenticity and prevent tampering (a verification sketch follows this list)
- Checksums: Calculate and verify checksums for all model artifacts
- Access control: Implement RBAC with least privilege principle
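The registry above leaves _verify_signature as a stub; a minimal verification sketch, assuming the signer's RSA public key is available in a public_keys lookup (an assumption, not part of the registry itself), could look like this:
import base64
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding

def verify_artifact_signature(artifact, public_keys: dict) -> bool:
    """Verify a ModelArtifact signature against the signer's stored public key."""
    public_key = public_keys.get(artifact.signed_by)
    if public_key is None or not artifact.signature:
        return False
    # Rebuild the exact content string that _sign_artifact signed
    content = f"{artifact.artifact_id}:{artifact.checksum}:{artifact.size_bytes}"
    try:
        public_key.verify(
            base64.b64decode(artifact.signature),
            content.encode(),
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )
        return True
    except InvalidSignature:
        return False
Because verification rebuilds the same artifact_id:checksum:size_bytes string that _sign_artifact signs, any change to those fields invalidates the signature.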
Versioning Strategy
- Use semantic versioning for models (see the sketch after this list)
- Maintain full lineage for production models
- Implement rollback capabilities with verification
- Archive deprecated models with retention policies
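One way to apply semantic versioning is to validate and compare version strings before accepting a registration. This is a minimal sketch; the parse_semver and is_newer helpers are illustrative and not part of the registry above:
import re
from typing import Tuple

_SEMVER = re.compile(r"^(\d+)\.(\d+)\.(\d+)$")

def parse_semver(version: str) -> Tuple[int, int, int]:
    """Parse a MAJOR.MINOR.PATCH string, rejecting anything else."""
    match = _SEMVER.match(version)
    if not match:
        raise ValueError(f"Invalid semantic version: {version}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch

def is_newer(candidate: str, current: str) -> bool:
    """Return True if candidate is strictly newer than current."""
    return parse_semver(candidate) > parse_semver(current)

# Example: only accept registrations that move the version forward
assert is_newer("2.1.0", "2.0.3")
assert not is_newer("2.0.3", "2.0.3")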
Compliance
- Maintain tamper-evident audit trails
- Generate compliance reports on demand
- Verify chain integrity regularly (see the sketch after this list)
- Store audit logs in immutable storage
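To tie the last two points together, a scheduled job might re-verify the hash chain and produce a rolling report using the ModelAuditTrail class above; the alert callback here is an assumed hook, not part of that class:
from datetime import datetime, timedelta

def run_compliance_check(audit_trail, model_name: str, alert) -> dict:
    """Re-verify the audit chain and build a 30-day compliance report."""
    is_valid, errors = audit_trail.verify_chain_integrity()
    if not is_valid:
        # Surface a possible tampering event immediately
        alert(f"Audit chain integrity failure for {model_name}: {errors}")
    end = datetime.utcnow()
    return audit_trail.generate_compliance_report(
        model_name=model_name,
        start_date=end - timedelta(days=30),
        end_date=end
    )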
Related Resources
- Data Versioning for ML: DVC, lakeFS, and Delta Lake — Version data alongside models
- Model Governance: Managing ML Models from Development to Retirement — Full lifecycle management
- MLOps Security: Securing Your ML Pipeline — End-to-end pipeline security
- MLOps Best Practices: Building Production-Ready ML Pipelines — Production pipeline patterns
- What is MLOps? — Complete MLOps overview
Need help securing your ML model lifecycle? DeviDevs builds secure MLOps platforms with built-in versioning, integrity verification, and compliance. Get a free assessment →