AI Supply Chain Security: Protecting Against Model and Data Poisoning
The AI supply chain extends far beyond traditional software dependencies. Pre-trained models, fine-tuning datasets, embeddings, and ML frameworks all represent potential attack vectors. A compromised component can introduce backdoors, bias, or vulnerabilities that persist through the entire AI lifecycle.
This guide examines AI supply chain risks and provides practical mitigation strategies.
Understanding the AI Supply Chain
┌──────────────────────────────────────────────────────────────────────────────┐
│                          AI Supply Chain Components                           │
├──────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│   External Sources           Your Organization          Deployment           │
│  ┌──────────────┐           ┌──────────────┐           ┌──────────────┐      │
│  │ Pre-trained  │──────────▶│ Fine-tuning  │──────────▶│ Production   │      │
│  │ Models       │           │ & Training   │           │ Model        │      │
│  └──────────────┘           └──────────────┘           └──────────────┘      │
│  ┌──────────────┐           ┌──────────────┐           ┌──────────────┐      │
│  │ Public       │──────────▶│ Data         │──────────▶│ Training     │      │
│  │ Datasets     │           │ Processing   │           │ Dataset      │      │
│  └──────────────┘           └──────────────┘           └──────────────┘      │
│  ┌──────────────┐           ┌──────────────┐           ┌──────────────┐      │
│  │ ML Libraries │──────────▶│ Development  │──────────▶│ Runtime      │      │
│  │ & Frameworks │           │ Environment  │           │ Environment  │      │
│  └──────────────┘           └──────────────┘           └──────────────┘      │
│                                                                              │
│  Attack Vectors:                                                             │
│   • Backdoored models       • Poisoned datasets       • Vulnerable deps      │
│   • Trojan model weights    • Mislabeled data         • Malicious code       │
│   • Stolen/leaked models    • Copyright violations    • Supply-chain attacks │
│                                                                              │
└──────────────────────────────────────────────────────────────────────────────┘
Model Supply Chain Risks
Pre-trained Model Threats
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum
class ModelThreatType(Enum):
BACKDOOR = "backdoor"
TROJAN = "trojan"
DATA_POISONING = "data_poisoning"
MODEL_THEFT = "model_theft"
WEIGHT_MANIPULATION = "weight_manipulation"
@dataclass
class ModelThreat:
threat_type: ModelThreatType
description: str
detection_difficulty: str
impact: str
mitigations: List[str]
class ModelSupplyChainThreats:
"""Catalog of pre-trained model supply chain threats."""
threats = {
ModelThreatType.BACKDOOR: ModelThreat(
threat_type=ModelThreatType.BACKDOOR,
description="""
Attacker embeds a hidden trigger in the model that causes
misclassification or specific behavior when the trigger is present.
""",
detection_difficulty="High - triggers may be subtle",
impact="""
Model behaves normally except when trigger present.
Could cause targeted misclassification, enable attacks,
or produce harmful outputs.
""",
mitigations=[
"Validate model source and provenance",
"Test model with known backdoor triggers",
"Use neural cleansing techniques",
"Fine-tune on trusted data to weaken backdoors",
"Implement output monitoring for anomalies"
]
),
ModelThreatType.TROJAN: ModelThreat(
threat_type=ModelThreatType.TROJAN,
description="""
Malicious functionality hidden in model that activates
under specific conditions or after time delay.
""",
detection_difficulty="Very High - may require code audit",
impact="""
Model may exfiltrate data, enable remote access,
or cause system compromise.
""",
mitigations=[
"Only use models from trusted sources",
"Audit model loading code for suspicious behavior",
"Monitor model execution for unexpected operations",
"Sandbox model execution environment",
"Implement network isolation for model inference"
]
),
ModelThreatType.DATA_POISONING: ModelThreat(
threat_type=ModelThreatType.DATA_POISONING,
description="""
Pre-trained model was trained on poisoned data,
embedding biases or vulnerabilities.
""",
detection_difficulty="Medium - statistical analysis can help",
impact="""
Model produces biased outputs, fails on specific inputs,
or is vulnerable to adversarial examples.
""",
mitigations=[
"Verify training data provenance",
"Test model on diverse evaluation sets",
"Perform bias audits",
"Monitor for unexpected behavior patterns",
"Fine-tune with verified data"
]
),
ModelThreatType.MODEL_THEFT: ModelThreat(
threat_type=ModelThreatType.MODEL_THEFT,
description="""
Using a stolen or leaked proprietary model that
creates legal liability or trust issues.
""",
detection_difficulty="Medium - watermarks may help",
impact="""
            Legal liability, risk of using a tampered copy,
            and support of a criminal ecosystem.
""",
mitigations=[
"Verify model licensing and provenance",
"Check for model watermarks",
"Only use officially released models",
"Maintain records of model sources"
]
)
    }
Secure Model Loading
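The loader below assumes a registry of approved models keyed by ID. A minimal sketch of that configuration and a verification call; the model name, file path, and checksum are hypothetical placeholders:
# Hypothetical loader configuration; model ID, path, and checksum are placeholders
from pathlib import Path

loader_config = {
    'trusted_sources': ['https://huggingface.co/approved-org'],
    'model_registry': {
        'sentiment-classifier-v2': {
            'checksum': '<sha256-of-approved-model-file>',
            # optional 'signature' / 'public_key' entries enable the signature check
        }
    },
    'verify_signatures': True,
}

loader = SecureModelLoader(loader_config)
report = loader.load_model(Path('models/sentiment-classifier-v2.pt'),
                           'sentiment-classifier-v2')
if not report['passed']:
    raise RuntimeError(f"Model rejected: {report['checks']}")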
import hashlib
from pathlib import Path
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
import torch
class SecureModelLoader:
"""Securely load pre-trained models with verification."""
def __init__(self, config: dict):
self.trusted_sources = config.get('trusted_sources', [])
self.model_registry = config.get('model_registry', {})
self.verify_signatures = config.get('verify_signatures', True)
def load_model(self, model_path: Path, model_id: str) -> dict:
"""
Securely load a model with verification.
"""
verification_result = {
'model_id': model_id,
'checks': [],
'passed': True
}
# Check 1: Verify model is in trusted registry
if model_id not in self.model_registry:
verification_result['checks'].append({
'check': 'registry_check',
'passed': False,
'message': 'Model not in trusted registry'
})
verification_result['passed'] = False
return verification_result
registry_entry = self.model_registry[model_id]
# Check 2: Verify file checksum
checksum_check = self._verify_checksum(model_path, registry_entry['checksum'])
verification_result['checks'].append(checksum_check)
if not checksum_check['passed']:
verification_result['passed'] = False
return verification_result
# Check 3: Verify digital signature
if self.verify_signatures and 'signature' in registry_entry:
sig_check = self._verify_signature(
model_path,
registry_entry['signature'],
registry_entry['public_key']
)
verification_result['checks'].append(sig_check)
if not sig_check['passed']:
verification_result['passed'] = False
return verification_result
# Check 4: Scan for malicious content
malware_check = self._scan_for_malware(model_path)
verification_result['checks'].append(malware_check)
if not malware_check['passed']:
verification_result['passed'] = False
return verification_result
# Check 5: Safe deserialization
safe_load_check = self._safe_load(model_path)
verification_result['checks'].append(safe_load_check)
if verification_result['passed']:
verification_result['model'] = safe_load_check.get('model')
return verification_result
def _verify_checksum(self, path: Path, expected: str) -> dict:
"""Verify file checksum."""
sha256 = hashlib.sha256()
with open(path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
sha256.update(chunk)
actual = sha256.hexdigest()
return {
'check': 'checksum',
'passed': actual == expected,
'expected': expected,
'actual': actual
}
def _verify_signature(self, path: Path, signature: bytes,
public_key) -> dict:
"""Verify cryptographic signature."""
try:
with open(path, 'rb') as f:
data = f.read()
public_key.verify(
signature,
data,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return {'check': 'signature', 'passed': True}
except Exception as e:
return {
'check': 'signature',
'passed': False,
'error': str(e)
}
def _scan_for_malware(self, path: Path) -> dict:
"""Scan model file for known malicious patterns."""
# Check for dangerous pickle opcodes
dangerous_opcodes = [
b'cos\nsystem', # os.system
b'csubprocess', # subprocess module
b'cbuiltins\neval', # eval
b'cbuiltins\nexec', # exec
b'c__builtin__\neval',
b'c__builtin__\nexec',
]
with open(path, 'rb') as f:
content = f.read()
for opcode in dangerous_opcodes:
if opcode in content:
return {
'check': 'malware_scan',
'passed': False,
'message': f'Dangerous opcode detected: {opcode}'
}
return {'check': 'malware_scan', 'passed': True}
def _safe_load(self, path: Path) -> dict:
"""Safely load model with restricted unpickler."""
if str(path).endswith('.pt') or str(path).endswith('.pth'):
# PyTorch model
try:
# Use weights_only=True for safer loading
model = torch.load(path, weights_only=True)
return {'check': 'safe_load', 'passed': True, 'model': model}
except Exception as e:
return {
'check': 'safe_load',
'passed': False,
'error': str(e)
}
elif str(path).endswith('.safetensors'):
# SafeTensors format (safer)
from safetensors import safe_open
try:
with safe_open(path, framework='pt') as f:
model = {k: f.get_tensor(k) for k in f.keys()}
return {'check': 'safe_load', 'passed': True, 'model': model}
except Exception as e:
return {
'check': 'safe_load',
'passed': False,
'error': str(e)
}
else:
return {
'check': 'safe_load',
'passed': False,
'error': 'Unsupported format'
            }
Dataset Supply Chain Security
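Dataset validation starts with documented metadata shipped alongside the data. A hypothetical record showing the fields the validators in this section inspect (values are illustrative):
# Hypothetical dataset metadata; field names match what the validators below check
dataset_metadata = {
    'source': 'https://example.org/open-text-corpus',    # ProvenanceValidator
    'collection_date': '2024-03-15',
    'collection_method': 'public web crawl, deduplicated',
    'license': 'CC-BY-4.0',                               # LicenseValidator
    'license_restrictions': ['attribution-required'],
}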
Dataset Validation Framework
import re
from datetime import datetime

class DatasetValidator:
    """Validate datasets for supply chain security."""
    def __init__(self, config: dict):
        # IntegrityValidator and BiasValidator follow the same validate() interface (not shown here)
        self.validators = [
            ProvenanceValidator(config.get('trusted_sources', [])),
            IntegrityValidator(),
            ContentValidator(),
            LicenseValidator(),
            BiasValidator()
        ]
def validate_dataset(self, dataset_path: str,
metadata: dict) -> dict:
"""
Comprehensive dataset validation.
"""
results = {
'dataset': dataset_path,
'validation_date': datetime.utcnow().isoformat(),
'checks': [],
'overall_status': 'passed'
}
for validator in self.validators:
check_result = validator.validate(dataset_path, metadata)
results['checks'].append(check_result)
if not check_result['passed']:
results['overall_status'] = 'failed'
return results
class ProvenanceValidator:
    """Validate dataset provenance and source."""
    def __init__(self, trusted_sources: list):
        self.trusted_sources = trusted_sources
    def validate(self, path: str, metadata: dict) -> dict:
"""Check dataset provenance."""
issues = []
# Check source is documented
if 'source' not in metadata:
issues.append('Source not documented')
# Check source is trusted
if metadata.get('source') not in self.trusted_sources:
issues.append(f"Source '{metadata.get('source')}' not in trusted list")
# Check collection date
if 'collection_date' not in metadata:
issues.append('Collection date not documented')
# Check collection method
if 'collection_method' not in metadata:
issues.append('Collection method not documented')
return {
'validator': 'provenance',
'passed': len(issues) == 0,
'issues': issues
}
class ContentValidator:
"""Validate dataset content for malicious patterns."""
def validate(self, path: str, metadata: dict) -> dict:
"""Check for malicious or poisoned content."""
issues = []
samples_checked = 0
suspicious_samples = []
# Load sample of dataset
samples = self._load_sample(path, sample_size=1000)
for i, sample in enumerate(samples):
samples_checked += 1
# Check for prompt injection patterns
if self._contains_injection_pattern(sample):
suspicious_samples.append({
'index': i,
'type': 'injection_pattern',
'sample': sample[:100]
})
# Check for suspicious URLs
if self._contains_suspicious_url(sample):
suspicious_samples.append({
'index': i,
'type': 'suspicious_url',
'sample': sample[:100]
})
# Check for encoded payloads
if self._contains_encoded_payload(sample):
suspicious_samples.append({
'index': i,
'type': 'encoded_payload',
'sample': sample[:100]
})
        suspicious_rate = len(suspicious_samples) / max(samples_checked, 1)
return {
'validator': 'content',
'passed': suspicious_rate < 0.01, # Less than 1% suspicious
'samples_checked': samples_checked,
'suspicious_samples': len(suspicious_samples),
'suspicious_rate': suspicious_rate,
'details': suspicious_samples[:10] # First 10 examples
}
def _contains_injection_pattern(self, text: str) -> bool:
"""Check for prompt injection patterns."""
patterns = [
r'ignore\s+(all\s+)?previous\s+instructions?',
r'disregard\s+(all\s+)?previous',
r'\[SYSTEM\]',
r'\[INST\]',
r'you\s+are\s+now',
]
return any(re.search(p, str(text), re.IGNORECASE) for p in patterns)
class LicenseValidator:
"""Validate dataset licensing."""
def validate(self, path: str, metadata: dict) -> dict:
"""Check dataset license compliance."""
issues = []
license_info = metadata.get('license')
if not license_info:
issues.append('No license information provided')
return {
'validator': 'license',
'passed': False,
'issues': issues
}
# Check license compatibility
allowed_licenses = [
'MIT', 'Apache-2.0', 'BSD-3-Clause', 'CC0-1.0',
'CC-BY-4.0', 'CC-BY-SA-4.0', 'public-domain'
]
if license_info not in allowed_licenses:
issues.append(f"License '{license_info}' may not be compatible")
# Check for license restrictions
restrictions = metadata.get('license_restrictions', [])
if 'commercial-use-prohibited' in restrictions:
issues.append('Commercial use prohibited')
if 'attribution-required' in restrictions:
# Not a blocker, just a note
pass
return {
'validator': 'license',
'passed': len(issues) == 0,
'license': license_info,
'issues': issues
        }
ML Library Security
Dependency Scanning
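The scanner below uses a hard-coded vulnerability table purely for illustration. In practice, each package/version pair would be resolved against a vulnerability database such as OSV; a minimal sketch of that lookup, assuming the requests package is available:
import requests

def query_osv(package: str, version: str) -> list:
    """Query the OSV database for known vulnerabilities in a PyPI package version."""
    response = requests.post(
        'https://api.osv.dev/v1/query',
        json={'package': {'name': package, 'ecosystem': 'PyPI'}, 'version': version},
        timeout=10,
    )
    response.raise_for_status()
    return response.json().get('vulns', [])

# Example: query_osv('torch', '1.12.0') returns a list of advisory records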
from datetime import datetime

class MLDependencyScanner:
    """Scan ML dependencies for security issues."""
critical_packages = [
'torch', 'tensorflow', 'transformers', 'numpy',
'scipy', 'pandas', 'scikit-learn', 'keras',
'onnx', 'onnxruntime', 'langchain', 'openai'
]
def scan_environment(self) -> dict:
"""Scan installed packages for vulnerabilities."""
results = {
'scan_date': datetime.utcnow().isoformat(),
'packages': [],
'vulnerabilities': [],
'recommendations': []
}
# Get installed packages
installed = self._get_installed_packages()
for pkg in self.critical_packages:
if pkg in installed:
pkg_info = {
'name': pkg,
'version': installed[pkg],
'vulnerabilities': []
}
# Check for known vulnerabilities
vulns = self._check_vulnerabilities(pkg, installed[pkg])
pkg_info['vulnerabilities'] = vulns
if vulns:
results['vulnerabilities'].extend([
{**v, 'package': pkg} for v in vulns
])
results['packages'].append(pkg_info)
# Generate recommendations
results['recommendations'] = self._generate_recommendations(results)
return results
def _check_vulnerabilities(self, package: str, version: str) -> list:
"""Check package version for known vulnerabilities."""
# In production, query vulnerability databases
# (PyPI Advisory Database, OSV, etc.)
known_vulns = {
'torch': {
'<1.13.0': [
{
'id': 'CVE-2022-XXXX',
'severity': 'HIGH',
'description': 'Arbitrary code execution via pickle'
}
]
},
'transformers': {
'<4.30.0': [
{
'id': 'CVE-2023-XXXX',
'severity': 'CRITICAL',
'description': 'Remote code execution in model loading'
}
]
}
}
vulns = []
if package in known_vulns:
for version_constraint, vuln_list in known_vulns[package].items():
if self._version_matches(version, version_constraint):
vulns.extend(vuln_list)
return vulns
def generate_sbom(self) -> dict:
"""Generate Software Bill of Materials for ML stack."""
sbom = {
'sbom_version': '1.0',
'created': datetime.utcnow().isoformat(),
'components': []
}
installed = self._get_installed_packages()
for pkg, version in installed.items():
component = {
'type': 'library',
'name': pkg,
'version': version,
'purl': f'pkg:pypi/{pkg}@{version}',
'licenses': self._get_package_license(pkg),
'hashes': self._get_package_hashes(pkg, version)
}
sbom['components'].append(component)
        return sbom
Secure Requirements Configuration
# requirements-secure.txt example with pinned versions and hashes
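# Install with hash checking enforced so tampered or substituted wheels are rejected:
#   pip install --require-hashes -r requirements-secure.txt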
# Core ML dependencies with hashes
torch==2.1.0 \
--hash=sha256:abc123... \
--hash=sha256:def456...
transformers==4.35.0 \
--hash=sha256:789abc...
# Use safetensors for safer model loading
safetensors==0.4.0 \
--hash=sha256:xyz123...
# Avoid pickle-based serialization where possible
# Use ONNX or SafeTensors formats
# Pin transitive dependencies
numpy==1.24.0 \
--hash=sha256:numpy_hash...
# Security-focused dependencies
# Dependency vulnerability scanning (run in CI, not shipped to production)
pip-audit==2.6.1 \
    --hash=sha256:pipaudit_hash...
# Note: scanners such as Trivy are standalone binaries installed outside pip
Model Registry Security
Secure Model Registry
import hashlib
import uuid
from datetime import datetime

class SecurityError(Exception): pass
class NotFoundError(Exception): pass
class IntegrityError(Exception): pass

class SecureModelRegistry:
    """Secure registry for approved ML models."""
def __init__(self, storage_backend, signing_key):
self.storage = storage_backend
self.signing_key = signing_key
self.models = {}
def register_model(self, model_info: dict, model_file: bytes,
registrant: str) -> str:
"""
Register a new model with security metadata.
"""
model_id = str(uuid.uuid4())
# Calculate checksums
sha256_hash = hashlib.sha256(model_file).hexdigest()
blake2_hash = hashlib.blake2b(model_file).hexdigest()
# Sign the model
signature = self._sign_model(model_file)
# Scan for vulnerabilities
scan_result = self._security_scan(model_file)
if not scan_result['passed']:
raise SecurityError(f"Model failed security scan: {scan_result['issues']}")
# Create registry entry
registry_entry = {
'model_id': model_id,
'name': model_info['name'],
'version': model_info['version'],
'description': model_info.get('description'),
'source': model_info.get('source'),
'license': model_info.get('license'),
'checksums': {
'sha256': sha256_hash,
'blake2b': blake2_hash
},
'signature': signature.hex(),
'registered_by': registrant,
'registered_at': datetime.utcnow().isoformat(),
'security_scan': scan_result,
'status': 'active'
}
# Store model file
self.storage.store(model_id, model_file)
# Store registry entry
self.models[model_id] = registry_entry
return model_id
def get_model(self, model_id: str, requester: str) -> dict:
"""
Retrieve model with verification.
"""
if model_id not in self.models:
raise NotFoundError(f"Model {model_id} not found")
entry = self.models[model_id]
if entry['status'] != 'active':
raise SecurityError(f"Model {model_id} is not active")
# Retrieve model file
model_file = self.storage.retrieve(model_id)
# Verify integrity
actual_hash = hashlib.sha256(model_file).hexdigest()
if actual_hash != entry['checksums']['sha256']:
raise IntegrityError("Model checksum mismatch - possible tampering")
# Verify signature
signature = bytes.fromhex(entry['signature'])
if not self._verify_signature(model_file, signature):
raise IntegrityError("Model signature verification failed")
# Log access
self._log_access(model_id, requester)
return {
'model_id': model_id,
'metadata': entry,
'model_file': model_file
}
def revoke_model(self, model_id: str, reason: str, revoker: str):
"""Revoke a model from the registry."""
if model_id not in self.models:
raise NotFoundError(f"Model {model_id} not found")
self.models[model_id]['status'] = 'revoked'
self.models[model_id]['revoked_at'] = datetime.utcnow().isoformat()
self.models[model_id]['revoked_by'] = revoker
self.models[model_id]['revoke_reason'] = reason
# Alert users of the model
        self._notify_model_users(model_id, reason)
Continuous Monitoring
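The monitor below expects a model registry and an alerting client in its config. A hypothetical wiring sketch; registry and alert_client stand in for whatever implementations are deployed:
# Hypothetical wiring for the monitor below; `registry` and `alert_client` are assumed
monitor = AISupplyChainMonitor({
    'model_registry': registry,      # e.g. a SecureModelRegistry instance
    'alert_system': alert_client,    # e.g. a Slack or PagerDuty wrapper
})
monitor.run_continuous_monitoring()  # blocking loop; typically run as a dedicated service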
import time

class AISupplyChainMonitor:
    """Continuous monitoring for AI supply chain security."""
def __init__(self, config: dict):
self.model_registry = config['model_registry']
self.alert_system = config['alert_system']
def run_continuous_monitoring(self):
"""Run continuous supply chain monitoring."""
while True:
# Check for new CVEs affecting dependencies
new_vulns = self._check_vulnerability_databases()
if new_vulns:
self._handle_new_vulnerabilities(new_vulns)
# Verify model integrity
integrity_issues = self._verify_model_integrity()
if integrity_issues:
self._handle_integrity_issues(integrity_issues)
# Check for compromised upstream sources
source_issues = self._check_upstream_sources()
if source_issues:
self._handle_source_issues(source_issues)
# Monitor model behavior for anomalies
behavior_anomalies = self._monitor_model_behavior()
if behavior_anomalies:
self._handle_behavior_anomalies(behavior_anomalies)
time.sleep(3600) # Check hourly
def _handle_new_vulnerabilities(self, vulns: list):
"""Handle newly discovered vulnerabilities."""
for vuln in vulns:
# Check if we're affected
affected_models = self._find_affected_models(vuln)
if affected_models:
self.alert_system.send_alert({
'type': 'vulnerability',
'severity': vuln['severity'],
'cve': vuln['id'],
'affected_models': affected_models,
'recommendation': vuln.get('remediation')
})
# For critical vulnerabilities, consider auto-revocation
if vuln['severity'] == 'CRITICAL':
for model_id in affected_models:
                        self._quarantine_model(model_id, vuln)
Conclusion
AI supply chain security requires vigilance across models, datasets, and dependencies. The complexity of modern ML systems means attack surfaces are broad and constantly evolving.
Key takeaways:
- Verify everything - Models, datasets, and dependencies all need validation
- Maintain provenance - Document where every component comes from
- Scan continuously - New vulnerabilities emerge constantly
- Use secure formats - Prefer SafeTensors over pickle (see the sketch after this list)
- Monitor behavior - Detection may come from runtime anomalies
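A minimal sketch of the SafeTensors point above, assuming PyTorch and the safetensors package are installed: weights are stored and restored as plain tensors, with no pickle deserialization that could execute code.
import torch
from safetensors.torch import save_file, load_file

weights = {'linear.weight': torch.randn(4, 8), 'linear.bias': torch.zeros(4)}
save_file(weights, 'model.safetensors')    # flat tensor storage, no embedded code paths
restored = load_file('model.safetensors')  # pure tensor deserialization, no pickle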
At DeviDevs, we help organizations secure their AI supply chains with comprehensive security programs. Contact us to discuss your AI security needs.