AI supply chain security protects the entire machine learning lifecycle, from data collection through model deployment. This guide covers practical implementations for securing ML pipelines: model provenance tracking, dependency scanning, and training data integrity verification.
Model Provenance Tracking
Implement comprehensive model lineage tracking that records the datasets, training configuration, dependencies, environment, and signed artifact behind every model:
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Optional, Any
import hashlib
import json
@dataclass
class DatasetInfo:
dataset_id: str
name: str
version: str
source: str
hash: str
size_bytes: int
row_count: int
created_at: datetime
schema: Dict
transformations: List[str]
@dataclass
class TrainingConfig:
framework: str
framework_version: str
hyperparameters: Dict
random_seed: int
epochs: int
batch_size: int
optimizer: str
learning_rate: float
@dataclass
class ModelArtifact:
model_id: str
name: str
version: str
model_hash: str
file_path: str
size_bytes: int
format: str
created_at: datetime
signature: Optional[str] = None
@dataclass
class ModelProvenance:
model_id: str
model_name: str
version: str
created_at: datetime
created_by: str
datasets: List[DatasetInfo]
training_config: Optional[TrainingConfig]
dependencies: Dict[str, str]
environment: Dict
metrics: Dict[str, float]
artifact: Optional[ModelArtifact]
parent_model: Optional[str] = None
attestations: List[Dict] = field(default_factory=list)
class ProvenanceTracker:
def __init__(self, storage_backend):
self.storage = storage_backend
self.current_provenance: Optional[ModelProvenance] = None
def start_training_run(
self,
model_name: str,
version: str,
created_by: str
) -> str:
"""Start tracking a new training run."""
model_id = self._generate_model_id(model_name, version)
self.current_provenance = ModelProvenance(
model_id=model_id,
model_name=model_name,
version=version,
created_at=datetime.utcnow(),
created_by=created_by,
datasets=[],
training_config=None,
dependencies={},
environment={},
metrics={},
artifact=None
)
# Capture environment
self.current_provenance.environment = self._capture_environment()
self.current_provenance.dependencies = self._capture_dependencies()
return model_id
def _generate_model_id(self, name: str, version: str) -> str:
timestamp = datetime.utcnow().isoformat()
content = f"{name}:{version}:{timestamp}"
return hashlib.sha256(content.encode()).hexdigest()[:16]
def _capture_environment(self) -> Dict:
import platform
import os
return {
'platform': platform.platform(),
'python_version': platform.python_version(),
'hostname': platform.node(),
'cpu_count': os.cpu_count(),
'cuda_available': self._check_cuda(),
'timestamp': datetime.utcnow().isoformat()
}
def _check_cuda(self) -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def _capture_dependencies(self) -> Dict[str, str]:
# pkg_resources is deprecated in newer setuptools; importlib.metadata is the modern alternative
import pkg_resources
return {
pkg.key: pkg.version
for pkg in pkg_resources.working_set
}
def register_dataset(
self,
name: str,
version: str,
source: str,
data_path: str,
transformations: List[str] = None
) -> DatasetInfo:
"""Register a dataset used in training."""
dataset_hash = self._hash_file(data_path)
file_size = self._get_file_size(data_path)
row_count = self._count_rows(data_path)
schema = self._infer_schema(data_path)
dataset = DatasetInfo(
dataset_id=self._generate_dataset_id(name, version),
name=name,
version=version,
source=source,
hash=dataset_hash,
size_bytes=file_size,
row_count=row_count,
created_at=datetime.utcnow(),
schema=schema,
transformations=transformations or []
)
if self.current_provenance:
self.current_provenance.datasets.append(dataset)
return dataset
def _generate_dataset_id(self, name: str, version: str) -> str:
# MD5 here only produces a short identifier; integrity hashing elsewhere uses SHA-256
return hashlib.md5(f"{name}:{version}".encode()).hexdigest()[:12]
def _hash_file(self, file_path: str) -> str:
sha256 = hashlib.sha256()
with open(file_path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
sha256.update(chunk)
return sha256.hexdigest()
def _get_file_size(self, file_path: str) -> int:
import os
return os.path.getsize(file_path)
def _count_rows(self, file_path: str) -> int:
if file_path.endswith('.csv'):
import pandas as pd
return len(pd.read_csv(file_path))
return -1
def _infer_schema(self, file_path: str) -> Dict:
if file_path.endswith('.csv'):
import pandas as pd
df = pd.read_csv(file_path, nrows=100)
return {col: str(dtype) for col, dtype in df.dtypes.items()}
return {}
def set_training_config(self, config: TrainingConfig):
"""Set training configuration."""
if self.current_provenance:
self.current_provenance.training_config = config
def log_metrics(self, metrics: Dict[str, float]):
"""Log training/evaluation metrics."""
if self.current_provenance:
self.current_provenance.metrics.update(metrics)
def register_model_artifact(
self,
model_path: str,
model_format: str = 'pytorch'
) -> ModelArtifact:
"""Register the trained model artifact."""
if not self.current_provenance:
raise ValueError("No active training run to attach the artifact to")
model_hash = self._hash_file(model_path)
artifact = ModelArtifact(
model_id=self.current_provenance.model_id,
name=self.current_provenance.model_name,
version=self.current_provenance.version,
model_hash=model_hash,
file_path=model_path,
size_bytes=self._get_file_size(model_path),
format=model_format,
created_at=datetime.utcnow()
)
self.current_provenance.artifact = artifact
return artifact
def sign_model(self, private_key_path: str):
"""Sign the model artifact for integrity verification."""
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
if not self.current_provenance or not self.current_provenance.artifact:
raise ValueError("No model artifact to sign")
# Load private key
with open(private_key_path, 'rb') as f:
private_key = serialization.load_pem_private_key(f.read(), password=None)
# Sign the model hash
model_hash = self.current_provenance.artifact.model_hash.encode()
signature = private_key.sign(
model_hash,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
import base64
self.current_provenance.artifact.signature = base64.b64encode(signature).decode()
def add_attestation(
self,
attestation_type: str,
attester: str,
claims: Dict
):
"""Add an attestation to the provenance record."""
attestation = {
'type': attestation_type,
'attester': attester,
'claims': claims,
'timestamp': datetime.utcnow().isoformat()
}
if self.current_provenance:
self.current_provenance.attestations.append(attestation)
def finalize_and_store(self) -> Dict:
"""Finalize and store the provenance record."""
if not self.current_provenance:
raise ValueError("No active provenance record")
provenance_dict = self._to_dict(self.current_provenance)
self.storage.store(self.current_provenance.model_id, provenance_dict)
return provenance_dict
def _to_dict(self, provenance: ModelProvenance) -> Dict:
return {
'model_id': provenance.model_id,
'model_name': provenance.model_name,
'version': provenance.version,
'created_at': provenance.created_at.isoformat(),
'created_by': provenance.created_by,
'datasets': [
{
'dataset_id': d.dataset_id,
'name': d.name,
'version': d.version,
'source': d.source,
'hash': d.hash,
'row_count': d.row_count
}
for d in provenance.datasets
],
'training_config': {
'framework': provenance.training_config.framework,
'hyperparameters': provenance.training_config.hyperparameters,
'epochs': provenance.training_config.epochs
} if provenance.training_config else None,
'dependencies': provenance.dependencies,
'environment': provenance.environment,
'metrics': provenance.metrics,
'artifact': {
'model_hash': provenance.artifact.model_hash,
'signature': provenance.artifact.signature,
'format': provenance.artifact.format
} if provenance.artifact else None,
'attestations': provenance.attestations
}
def verify_model(self, model_id: str, model_path: str, public_key_path: str) -> Dict:
"""Verify model integrity against provenance record."""
provenance = self.storage.get(model_id)
if not provenance:
return {'valid': False, 'error': 'Provenance record not found'}
# Verify hash
current_hash = self._hash_file(model_path)
expected_hash = provenance['artifact']['model_hash']
if current_hash != expected_hash:
return {
'valid': False,
'error': 'Model hash mismatch',
'expected': expected_hash,
'actual': current_hash
}
# Verify signature if present
if provenance['artifact'].get('signature'):
signature_valid = self._verify_signature(
current_hash,
provenance['artifact']['signature'],
public_key_path
)
if not signature_valid:
return {'valid': False, 'error': 'Invalid signature'}
return {
'valid': True,
'model_id': model_id,
'verified_at': datetime.utcnow().isoformat()
}
def _verify_signature(self, model_hash: str, signature_b64: str, public_key_path: str) -> bool:
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
import base64
with open(public_key_path, 'rb') as f:
public_key = serialization.load_pem_public_key(f.read())
try:
signature = base64.b64decode(signature_b64)
public_key.verify(
signature,
model_hash.encode(),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return True
except Exception:
return False
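Before moving on, here is a minimal sketch of how the tracker might be wired together. The in-memory backend, file paths, key locations, and model names below are illustrative assumptions; any storage object exposing store() and get() will work.

# Illustrative usage of ProvenanceTracker (backend, paths, and names are assumptions).
class InMemoryStorage:
    """Toy storage backend exposing the store()/get() interface the tracker expects."""
    def __init__(self):
        self._records = {}

    def store(self, model_id, record):
        self._records[model_id] = record

    def get(self, model_id):
        return self._records.get(model_id)


tracker = ProvenanceTracker(storage_backend=InMemoryStorage())
model_id = tracker.start_training_run("fraud-detector", "1.2.0", created_by="ml-team")

tracker.register_dataset(
    name="transactions",
    version="2024-01",
    source="s3://example-bucket/transactions.csv",  # assumed source location
    data_path="data/transactions.csv",               # assumed local copy
    transformations=["dedupe", "normalize_amounts"],
)
tracker.set_training_config(TrainingConfig(
    framework="pytorch", framework_version="2.2.0",
    hyperparameters={"hidden_size": 128}, random_seed=42,
    epochs=10, batch_size=64, optimizer="adam", learning_rate=1e-3,
))
tracker.log_metrics({"val_auc": 0.94})
tracker.register_model_artifact("models/fraud-detector.pt", model_format="pytorch")

# Signing assumes an unencrypted RSA private key in PEM format, e.g. generated with:
#   openssl genrsa -out signing_key.pem 4096
#   openssl rsa -in signing_key.pem -pubout -out signing_key.pub
tracker.sign_model("signing_key.pem")
record = tracker.finalize_and_store()

# Later, before deployment, verify the artifact against its provenance record.
result = tracker.verify_model(model_id, "models/fraud-detector.pt", "signing_key.pub")
print(result["valid"])

The same verify_model call can run in the deployment pipeline, so an artifact that no longer matches its recorded hash or signature never reaches production.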
ML Dependency Scanner
Scan ML dependencies for vulnerabilities:
from dataclasses import dataclass
from typing import List, Dict, Optional
from datetime import datetime
@dataclass
class VulnerabilityFinding:
package: str
installed_version: str
vulnerable_versions: str
severity: str
cve_id: Optional[str]
description: str
fix_version: Optional[str]
class MLDependencyScanner:
def __init__(self):
self.ml_packages = [
'torch', 'tensorflow', 'keras', 'scikit-learn', 'numpy',
'pandas', 'scipy', 'transformers', 'huggingface-hub',
'onnx', 'onnxruntime', 'mlflow', 'wandb', 'pytorch-lightning'
]
self.known_vulns = self._load_known_vulnerabilities()
def _load_known_vulnerabilities(self) -> Dict:
# In production, fetch from vulnerability database
return {
'torch': [
{
'versions': '<1.9.0',
'severity': 'HIGH',
'cve': 'CVE-2022-XXXXX',
'description': 'Arbitrary code execution via pickle',
'fix_version': '1.9.0'
}
],
'tensorflow': [
{
'versions': '<2.8.4',
'severity': 'CRITICAL',
'cve': 'CVE-2022-35959',
'description': 'CHECK failure in FractionalMaxPoolGrad',
'fix_version': '2.8.4'
}
],
'numpy': [
{
'versions': '<1.22.0',
'severity': 'HIGH',
'cve': 'CVE-2021-41495',
'description': 'NULL pointer dereference in numpy.sort',
'fix_version': '1.22.0'
}
],
'pillow': [
{
'versions': '<9.0.0',
'severity': 'HIGH',
'cve': 'CVE-2022-22815',
'description': 'Path traversal vulnerability',
'fix_version': '9.0.0'
}
]
}
def scan_environment(self) -> List[VulnerabilityFinding]:
"""Scan installed packages for vulnerabilities."""
import pkg_resources
findings = []
installed = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
for package, vulns in self.known_vulns.items():
if package.lower() not in installed:
continue
installed_version = installed[package.lower()]
for vuln in vulns:
if self._version_in_range(installed_version, vuln['versions']):
findings.append(VulnerabilityFinding(
package=package,
installed_version=installed_version,
vulnerable_versions=vuln['versions'],
severity=vuln['severity'],
cve_id=vuln.get('cve'),
description=vuln['description'],
fix_version=vuln.get('fix_version')
))
return findings
def _version_in_range(self, version: str, range_spec: str) -> bool:
from packaging import version as pkg_version
installed = pkg_version.parse(version)
# Check two-character operators before single-character ones so '<=' is not matched as '<'
if range_spec.startswith('<='):
return installed <= pkg_version.parse(range_spec[2:])
elif range_spec.startswith('>='):
return installed >= pkg_version.parse(range_spec[2:])
elif range_spec.startswith('=='):
return installed == pkg_version.parse(range_spec[2:])
elif range_spec.startswith('<'):
return installed < pkg_version.parse(range_spec[1:])
elif range_spec.startswith('>'):
return installed > pkg_version.parse(range_spec[1:])
return False
def scan_requirements_file(self, file_path: str) -> List[VulnerabilityFinding]:
"""Scan requirements file for vulnerabilities."""
findings = []
with open(file_path, 'r') as f:
for line in f:
line = line.strip()
if not line or line.startswith('#'):
continue
# Parse package==version
if '==' in line:
package, version = line.split('==')
package = package.strip()
version = version.strip()
if package.lower() in self.known_vulns:
for vuln in self.known_vulns[package.lower()]:
if self._version_in_range(version, vuln['versions']):
findings.append(VulnerabilityFinding(
package=package,
installed_version=version,
vulnerable_versions=vuln['versions'],
severity=vuln['severity'],
cve_id=vuln.get('cve'),
description=vuln['description'],
fix_version=vuln.get('fix_version')
))
return findings
def scan_model_dependencies(self, model_path: str) -> List[VulnerabilityFinding]:
"""Scan dependencies embedded in model files."""
findings = []
# For PyTorch models
if model_path.endswith('.pt') or model_path.endswith('.pth'):
try:
import torch
# torch.load unpickles arbitrary objects; only scan files from trusted sources,
# or pass weights_only=True where the installed PyTorch version supports it
checkpoint = torch.load(model_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'dependencies' in checkpoint:
deps = checkpoint['dependencies']
for package, version in deps.items():
if package.lower() in self.known_vulns:
for vuln in self.known_vulns[package.lower()]:
if self._version_in_range(version, vuln['versions']):
findings.append(VulnerabilityFinding(
package=package,
installed_version=version,
vulnerable_versions=vuln['versions'],
severity=vuln['severity'],
cve_id=vuln.get('cve'),
description=vuln['description'],
fix_version=vuln.get('fix_version')
))
except Exception as e:
print(f"Failed to scan model: {e}")
return findings
def generate_report(self, findings: List[VulnerabilityFinding]) -> Dict:
"""Generate vulnerability report."""
severity_counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
for f in findings:
severity_counts[f.severity] = severity_counts.get(f.severity, 0) + 1
return {
'scan_time': datetime.utcnow().isoformat(),
'summary': {
'total_vulnerabilities': len(findings),
'by_severity': severity_counts
},
'findings': [
{
'package': f.package,
'installed_version': f.installed_version,
'severity': f.severity,
'cve': f.cve_id,
'description': f.description,
'fix_version': f.fix_version
}
for f in sorted(findings, key=lambda x: {'CRITICAL': 0, 'HIGH': 1, 'MEDIUM': 2, 'LOW': 3}.get(x.severity, 4))
]
}
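A short sketch of running the scanner in a build step follows; the requirements file name is an assumption. Because the bundled vulnerability list is static, in production you would refresh it from an advisory feed; established tools such as pip-audit or Safety can complement this scanner by querying the PyPI/OSV advisory databases.

# Illustrative CI usage of MLDependencyScanner (file name is an assumption).
scanner = MLDependencyScanner()

# Scan the active Python environment and a pinned requirements file.
findings = scanner.scan_environment()
findings += scanner.scan_requirements_file("requirements.txt")

report = scanner.generate_report(findings)
print(f"Total vulnerabilities: {report['summary']['total_vulnerabilities']}")
for finding in report['findings']:
    print(f"{finding['severity']:8} {finding['package']}=={finding['installed_version']}"
          f" -> fix: {finding['fix_version']}")

# Fail the pipeline on critical or high findings.
counts = report['summary']['by_severity']
if counts.get('CRITICAL', 0) or counts.get('HIGH', 0):
    raise SystemExit("Vulnerable ML dependencies detected")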
Training Data Integrity
Protect training data integrity with source, checksum, poisoning, and label-distribution checks:
from dataclasses import dataclass
from typing import List, Dict, Optional
import hashlib
from datetime import datetime
@dataclass
class DataIntegrityCheck:
check_type: str
passed: bool
details: Dict
timestamp: datetime
class TrainingDataIntegrity:
def __init__(self):
self.checks: List[DataIntegrityCheck] = []
def verify_data_source(self, source_config: Dict) -> DataIntegrityCheck:
"""Verify data source authenticity."""
check = DataIntegrityCheck(
check_type='source_verification',
passed=True,
details={},
timestamp=datetime.utcnow()
)
# Verify source URL/path
source_url = source_config.get('url') or source_config.get('path')
if not source_url:
check.passed = False
check.details['error'] = 'No source URL or path specified'
return check
# Check if source is from trusted registry
trusted_sources = [
'huggingface.co',
'kaggle.com',
's3.amazonaws.com',
'storage.googleapis.com'
]
# Substring matching is illustrative only; in production, parse the URL and validate the hostname
if any(ts in source_url for ts in trusted_sources):
check.details['source_trusted'] = True
else:
check.details['source_trusted'] = False
check.details['warning'] = 'Source not from trusted registry'
self.checks.append(check)
return check
def verify_checksum(self, file_path: str, expected_hash: str, algorithm: str = 'sha256') -> DataIntegrityCheck:
"""Verify file checksum."""
if algorithm == 'sha256':
hasher = hashlib.sha256()
elif algorithm == 'md5':
hasher = hashlib.md5()
else:
raise ValueError(f"Unsupported algorithm: {algorithm}")
with open(file_path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
hasher.update(chunk)
actual_hash = hasher.hexdigest()
check = DataIntegrityCheck(
check_type='checksum_verification',
passed=actual_hash == expected_hash,
details={
'algorithm': algorithm,
'expected': expected_hash,
'actual': actual_hash
},
timestamp=datetime.utcnow()
)
self.checks.append(check)
return check
def detect_data_poisoning(self, dataset, baseline_stats: Dict = None) -> DataIntegrityCheck:
"""Detect potential data poisoning."""
import numpy as np
check = DataIntegrityCheck(
check_type='poisoning_detection',
passed=True,
details={'anomalies': []},
timestamp=datetime.utcnow()
)
# Statistical analysis
if hasattr(dataset, 'shape'):
data = dataset
else:
data = np.array(dataset)
current_stats = {
'mean': float(np.mean(data)),
'std': float(np.std(data)),
'min': float(np.min(data)),
'max': float(np.max(data))
}
check.details['current_stats'] = current_stats
if baseline_stats:
# Compare with baseline
mean_shift = abs(current_stats['mean'] - baseline_stats['mean'])
std_change = abs(current_stats['std'] - baseline_stats['std'])
if mean_shift > baseline_stats['std'] * 2:
check.details['anomalies'].append({
'type': 'mean_shift',
'magnitude': mean_shift,
'threshold': baseline_stats['std'] * 2
})
check.passed = False
if std_change > baseline_stats['std']:
check.details['anomalies'].append({
'type': 'variance_change',
'magnitude': std_change,
'threshold': baseline_stats['std']
})
check.passed = False
# Check for outliers
outlier_threshold = 3
z_scores = np.abs((data - np.mean(data)) / np.std(data))
outlier_ratio = np.sum(z_scores > outlier_threshold) / len(data)
if outlier_ratio > 0.01: # More than 1% outliers
check.details['anomalies'].append({
'type': 'high_outlier_ratio',
'ratio': float(outlier_ratio),
'threshold': 0.01
})
check.passed = False
self.checks.append(check)
return check
def verify_label_integrity(self, labels, expected_distribution: Dict = None) -> DataIntegrityCheck:
"""Verify label distribution integrity."""
import numpy as np
from collections import Counter
check = DataIntegrityCheck(
check_type='label_integrity',
passed=True,
details={},
timestamp=datetime.utcnow()
)
label_counts = Counter(labels)
total = sum(label_counts.values())
distribution = {k: v/total for k, v in label_counts.items()}
check.details['current_distribution'] = distribution
if expected_distribution:
for label, expected_ratio in expected_distribution.items():
actual_ratio = distribution.get(label, 0)
if abs(actual_ratio - expected_ratio) > 0.1: # 10% tolerance
check.details.setdefault('deviations', []).append({
'label': label,
'expected': expected_ratio,
'actual': actual_ratio
})
check.passed = False
self.checks.append(check)
return check
def generate_integrity_report(self) -> Dict:
"""Generate comprehensive integrity report."""
return {
'report_time': datetime.utcnow().isoformat(),
'overall_status': 'PASS' if all(c.passed for c in self.checks) else 'FAIL',
'checks': [
{
'type': c.check_type,
'passed': c.passed,
'details': c.details,
'timestamp': c.timestamp.isoformat()
}
for c in self.checks
]
}
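A brief sketch of chaining these checks before a training run is shown below. The source URL, file paths, expected checksum, baseline statistics, and label distribution are placeholders that you would record from a previously vetted version of the dataset.

# Illustrative pre-training integrity gate (all values are placeholders).
import numpy as np

integrity = TrainingDataIntegrity()

# 1. Confirm the dataset comes from an expected location.
integrity.verify_data_source({"url": "https://huggingface.co/datasets/example/train.csv"})

# 2. Compare the downloaded file against a checksum recorded at ingestion time.
integrity.verify_checksum(
    "data/train.csv",
    expected_hash="<sha256 recorded when the dataset was first vetted>",
)

# 3. Compare feature statistics and label balance against a trusted baseline.
features = np.load("data/train_features.npy")        # assumed preprocessed features
labels = np.load("data/train_labels.npy").tolist()
integrity.detect_data_poisoning(features, baseline_stats={"mean": 0.0, "std": 1.0,
                                                          "min": -5.0, "max": 5.0})
integrity.verify_label_integrity(labels, expected_distribution={0: 0.9, 1: 0.1})

report = integrity.generate_integrity_report()
print(report["overall_status"])

Feeding the resulting report into the provenance tracker's add_attestation method ties the integrity evidence to the model's lineage record.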
Conclusion
AI supply chain security requires protecting every stage of the ML lifecycle. Implement provenance tracking to maintain model lineage, scan dependencies for known vulnerabilities, and verify training data integrity before it reaches the pipeline. Sign model artifacts so that deployments can be verified against their provenance records. AI supply chain attacks can be subtle, so continuous monitoring and regular security assessments are essential for maintaining a secure ML pipeline.
Related Resources
- MLOps Security: Securing Your ML Pipeline — End-to-end pipeline security patterns
- Data Versioning for ML: DVC, lakeFS, and Delta Lake — Version and verify training data
- Model Governance: Managing ML Models from Development to Retirement — Audit trails and compliance
- ML CI/CD: Continuous Integration and Deployment for Machine Learning — Secure ML deployment pipelines
- What is MLOps? — Complete MLOps overview
Need help securing your ML supply chain? DeviDevs builds secure MLOps platforms with provenance tracking and integrity verification. Get a free assessment →