AI Security

AI Supply Chain Security: Protecting Against Model and Data Poisoning

DeviDevs Team
12 min read
#supply-chain #ai-security #ml-security #model-security #dependency-security


The AI supply chain extends far beyond traditional software dependencies. Pre-trained models, fine-tuning datasets, embeddings, and ML frameworks all represent potential attack vectors. A compromised component can introduce backdoors, bias, or vulnerabilities that persist through the entire AI lifecycle.

This guide examines AI supply chain risks and provides practical mitigation strategies.

Understanding the AI Supply Chain

┌─────────────────────────────────────────────────────────────────────────────┐
│                        AI Supply Chain Components                           │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  External Sources                Your Organization              Deployment  │
│  ┌─────────────┐               ┌─────────────┐               ┌───────────┐ │
│  │Pre-trained  │──────────────▶│Fine-tuning  │──────────────▶│Production │ │
│  │Models       │               │& Training   │               │Model      │ │
│  └─────────────┘               └─────────────┘               └───────────┘ │
│  ┌─────────────┐               ┌─────────────┐               ┌───────────┐ │
│  │Public       │──────────────▶│Data         │──────────────▶│Training   │ │
│  │Datasets     │               │Processing   │               │Dataset    │ │
│  └─────────────┘               └─────────────┘               └───────────┘ │
│  ┌─────────────┐               ┌─────────────┐               ┌───────────┐ │
│  │ML Libraries │──────────────▶│Development  │──────────────▶│Runtime    │ │
│  │& Frameworks │               │Environment  │               │Environment│ │
│  └─────────────┘               └─────────────┘               └───────────┘ │
│                                                                             │
│  Attack Vectors:                                                            │
│  • Backdoored models          • Poisoned datasets        • Vulnerable deps │
│  • Trojan model weights       • Mislabeled data          • Malicious code  │
│  • Stolen/leaked models       • Copyright violations     • Supply attacks  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Model Supply Chain Risks

Pre-trained Model Threats

from dataclasses import dataclass
from typing import List
from enum import Enum
 
class ModelThreatType(Enum):
    BACKDOOR = "backdoor"
    TROJAN = "trojan"
    DATA_POISONING = "data_poisoning"
    MODEL_THEFT = "model_theft"
    WEIGHT_MANIPULATION = "weight_manipulation"
 
@dataclass
class ModelThreat:
    threat_type: ModelThreatType
    description: str
    detection_difficulty: str
    impact: str
    mitigations: List[str]
 
class ModelSupplyChainThreats:
    """Catalog of pre-trained model supply chain threats."""
 
    threats = {
        ModelThreatType.BACKDOOR: ModelThreat(
            threat_type=ModelThreatType.BACKDOOR,
            description="""
                Attacker embeds a hidden trigger in the model that causes
                misclassification or specific behavior when the trigger is present.
            """,
            detection_difficulty="High - triggers may be subtle",
            impact="""
                Model behaves normally except when trigger present.
                Could cause targeted misclassification, enable attacks,
                or produce harmful outputs.
            """,
            mitigations=[
                "Validate model source and provenance",
                "Test model with known backdoor triggers",
                "Use neural cleansing techniques",
                "Fine-tune on trusted data to weaken backdoors",
                "Implement output monitoring for anomalies"
            ]
        ),
 
        ModelThreatType.TROJAN: ModelThreat(
            threat_type=ModelThreatType.TROJAN,
            description="""
                Malicious functionality hidden in model that activates
                under specific conditions or after time delay.
            """,
            detection_difficulty="Very High - may require code audit",
            impact="""
                Model may exfiltrate data, enable remote access,
                or cause system compromise.
            """,
            mitigations=[
                "Only use models from trusted sources",
                "Audit model loading code for suspicious behavior",
                "Monitor model execution for unexpected operations",
                "Sandbox model execution environment",
                "Implement network isolation for model inference"
            ]
        ),
 
        ModelThreatType.DATA_POISONING: ModelThreat(
            threat_type=ModelThreatType.DATA_POISONING,
            description="""
                Pre-trained model was trained on poisoned data,
                embedding biases or vulnerabilities.
            """,
            detection_difficulty="Medium - statistical analysis can help",
            impact="""
                Model produces biased outputs, fails on specific inputs,
                or is vulnerable to adversarial examples.
            """,
            mitigations=[
                "Verify training data provenance",
                "Test model on diverse evaluation sets",
                "Perform bias audits",
                "Monitor for unexpected behavior patterns",
                "Fine-tune with verified data"
            ]
        ),
 
        ModelThreatType.MODEL_THEFT: ModelThreat(
            threat_type=ModelThreatType.MODEL_THEFT,
            description="""
                Using a stolen or leaked proprietary model that
                creates legal liability or trust issues.
            """,
            detection_difficulty="Medium - watermarks may help",
            impact="""
                Legal liability, using compromised version,
                supporting criminal ecosystem.
            """,
            mitigations=[
                "Verify model licensing and provenance",
                "Check for model watermarks",
                "Only use officially released models",
                "Maintain records of model sources"
            ]
        )
    }
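
As a quick usage sketch, the catalog above can be turned into a pre-deployment review checklist (assuming the classes above are importable as shown):

# Sketch: print a review checklist from the threat catalog
def print_review_checklist():
    for threat_type, threat in ModelSupplyChainThreats.threats.items():
        print(f"\n[{threat_type.value}] detection difficulty: {threat.detection_difficulty}")
        for mitigation in threat.mitigations:
            print(f"  - {mitigation}")

print_review_checklist()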

Secure Model Loading

import hashlib
from pathlib import Path
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
import torch
 
class SecureModelLoader:
    """Securely load pre-trained models with verification."""
 
    def __init__(self, config: dict):
        self.trusted_sources = config.get('trusted_sources', [])
        self.model_registry = config.get('model_registry', {})
        self.verify_signatures = config.get('verify_signatures', True)
 
    def load_model(self, model_path: Path, model_id: str) -> dict:
        """
        Securely load a model with verification.
        """
 
        verification_result = {
            'model_id': model_id,
            'checks': [],
            'passed': True
        }
 
        # Check 1: Verify model is in trusted registry
        if model_id not in self.model_registry:
            verification_result['checks'].append({
                'check': 'registry_check',
                'passed': False,
                'message': 'Model not in trusted registry'
            })
            verification_result['passed'] = False
            return verification_result
 
        registry_entry = self.model_registry[model_id]
 
        # Check 2: Verify file checksum
        checksum_check = self._verify_checksum(model_path, registry_entry['checksum'])
        verification_result['checks'].append(checksum_check)
        if not checksum_check['passed']:
            verification_result['passed'] = False
            return verification_result
 
        # Check 3: Verify digital signature
        if self.verify_signatures and 'signature' in registry_entry:
            sig_check = self._verify_signature(
                model_path,
                registry_entry['signature'],
                registry_entry['public_key']
            )
            verification_result['checks'].append(sig_check)
            if not sig_check['passed']:
                verification_result['passed'] = False
                return verification_result
 
        # Check 4: Scan for malicious content
        malware_check = self._scan_for_malware(model_path)
        verification_result['checks'].append(malware_check)
        if not malware_check['passed']:
            verification_result['passed'] = False
            return verification_result
 
        # Check 5: Safe deserialization
        safe_load_check = self._safe_load(model_path)
        verification_result['checks'].append(safe_load_check)
        if not safe_load_check['passed']:
            verification_result['passed'] = False
            return verification_result

        verification_result['model'] = safe_load_check['model']
        return verification_result
 
    def _verify_checksum(self, path: Path, expected: str) -> dict:
        """Verify file checksum."""
        sha256 = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
 
        actual = sha256.hexdigest()
        return {
            'check': 'checksum',
            'passed': actual == expected,
            'expected': expected,
            'actual': actual
        }
 
    def _verify_signature(self, path: Path, signature: bytes,
                         public_key) -> dict:
        """Verify cryptographic signature."""
        try:
            with open(path, 'rb') as f:
                data = f.read()
 
            public_key.verify(
                signature,
                data,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
            return {'check': 'signature', 'passed': True}
 
        except Exception as e:
            return {
                'check': 'signature',
                'passed': False,
                'error': str(e)
            }
 
    def _scan_for_malware(self, path: Path) -> dict:
        """Scan model file for known malicious patterns."""
 
        # Check for dangerous pickle opcodes
        dangerous_opcodes = [
            b'cos\nsystem',  # os.system
            b'csubprocess',   # subprocess module
            b'cbuiltins\neval', # eval
            b'cbuiltins\nexec', # exec
            b'c__builtin__\neval',
            b'c__builtin__\nexec',
        ]
 
        with open(path, 'rb') as f:
            content = f.read()
 
        for opcode in dangerous_opcodes:
            if opcode in content:
                return {
                    'check': 'malware_scan',
                    'passed': False,
                    'message': f'Dangerous opcode detected: {opcode}'
                }
 
        return {'check': 'malware_scan', 'passed': True}
 
    def _safe_load(self, path: Path) -> dict:
        """Safely load model with restricted unpickler."""
 
        if str(path).endswith('.pt') or str(path).endswith('.pth'):
            # PyTorch model
            try:
                # Use weights_only=True for safer loading
                model = torch.load(path, weights_only=True)
                return {'check': 'safe_load', 'passed': True, 'model': model}
            except Exception as e:
                return {
                    'check': 'safe_load',
                    'passed': False,
                    'error': str(e)
                }
 
        elif str(path).endswith('.safetensors'):
            # SafeTensors format (safer)
            from safetensors import safe_open
            try:
                with safe_open(path, framework='pt') as f:
                    model = {k: f.get_tensor(k) for k in f.keys()}
                return {'check': 'safe_load', 'passed': True, 'model': model}
            except Exception as e:
                return {
                    'check': 'safe_load',
                    'passed': False,
                    'error': str(e)
                }
 
        else:
            return {
                'check': 'safe_load',
                'passed': False,
                'error': 'Unsupported format'
            }
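
A minimal usage sketch, assuming a local registry dict holding the model's expected SHA-256; the model path, trusted source, and checksum values below are placeholders:

# Sketch: verify a registered model before serving it
registry = {
    "sentiment-classifier-v2": {
        "checksum": "expected-sha256-hex-digest",  # placeholder value
        # optional 'signature' / 'public_key' entries omitted in this sketch
    }
}

loader = SecureModelLoader({
    "trusted_sources": ["models.internal.example.com"],  # hypothetical source
    "model_registry": registry,
    "verify_signatures": False,  # no signing key in this example
})

result = loader.load_model(
    Path("models/sentiment-classifier-v2.safetensors"),
    "sentiment-classifier-v2"
)
if not result["passed"]:
    raise RuntimeError(f"Model verification failed: {result['checks']}")
model_weights = result["model"]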

Dataset Supply Chain Security

Dataset Validation Framework

import re
from datetime import datetime

class DatasetValidator:
    """Validate datasets for supply chain security."""
 
    def __init__(self, config: dict):
        # IntegrityValidator and BiasValidator follow the same validate()
        # interface; only some validators are shown in full below.
        self.validators = [
            ProvenanceValidator(config.get('trusted_sources', [])),
            IntegrityValidator(),
            ContentValidator(),
            LicenseValidator(),
            BiasValidator()
        ]
 
    def validate_dataset(self, dataset_path: str,
                        metadata: dict) -> dict:
        """
        Comprehensive dataset validation.
        """
 
        results = {
            'dataset': dataset_path,
            'validation_date': datetime.utcnow().isoformat(),
            'checks': [],
            'overall_status': 'passed'
        }
 
        for validator in self.validators:
            check_result = validator.validate(dataset_path, metadata)
            results['checks'].append(check_result)
 
            if not check_result['passed']:
                results['overall_status'] = 'failed'
 
        return results
 
 
class ProvenanceValidator:
    """Validate dataset provenance and source."""

    def __init__(self, trusted_sources: list = None):
        self.trusted_sources = trusted_sources or []

    def validate(self, path: str, metadata: dict) -> dict:
        """Check dataset provenance."""
 
        issues = []
 
        # Check source is documented
        if 'source' not in metadata:
            issues.append('Source not documented')
 
        # Check source is trusted
        if metadata.get('source') not in self.trusted_sources:
            issues.append(f"Source '{metadata.get('source')}' not in trusted list")
 
        # Check collection date
        if 'collection_date' not in metadata:
            issues.append('Collection date not documented')
 
        # Check collection method
        if 'collection_method' not in metadata:
            issues.append('Collection method not documented')
 
        return {
            'validator': 'provenance',
            'passed': len(issues) == 0,
            'issues': issues
        }
 
 
class ContentValidator:
    """Validate dataset content for malicious patterns."""
 
    def validate(self, path: str, metadata: dict) -> dict:
        """Check for malicious or poisoned content."""
 
        issues = []
        samples_checked = 0
        suspicious_samples = []
 
        # Load a sample of the dataset (_load_sample and the URL/payload
        # helpers referenced below are omitted here for brevity)
        samples = self._load_sample(path, sample_size=1000)
 
        for i, sample in enumerate(samples):
            samples_checked += 1
 
            # Check for prompt injection patterns
            if self._contains_injection_pattern(sample):
                suspicious_samples.append({
                    'index': i,
                    'type': 'injection_pattern',
                    'sample': sample[:100]
                })
 
            # Check for suspicious URLs
            if self._contains_suspicious_url(sample):
                suspicious_samples.append({
                    'index': i,
                    'type': 'suspicious_url',
                    'sample': sample[:100]
                })
 
            # Check for encoded payloads
            if self._contains_encoded_payload(sample):
                suspicious_samples.append({
                    'index': i,
                    'type': 'encoded_payload',
                    'sample': sample[:100]
                })
 
        suspicious_rate = len(suspicious_samples) / max(samples_checked, 1)
 
        return {
            'validator': 'content',
            'passed': suspicious_rate < 0.01,  # Less than 1% suspicious
            'samples_checked': samples_checked,
            'suspicious_samples': len(suspicious_samples),
            'suspicious_rate': suspicious_rate,
            'details': suspicious_samples[:10]  # First 10 examples
        }
 
    def _contains_injection_pattern(self, text: str) -> bool:
        """Check for prompt injection patterns."""
        patterns = [
            r'ignore\s+(all\s+)?previous\s+instructions?',
            r'disregard\s+(all\s+)?previous',
            r'\[SYSTEM\]',
            r'\[INST\]',
            r'you\s+are\s+now',
        ]
        return any(re.search(p, str(text), re.IGNORECASE) for p in patterns)
 
 
class LicenseValidator:
    """Validate dataset licensing."""
 
    def validate(self, path: str, metadata: dict) -> dict:
        """Check dataset license compliance."""
 
        issues = []
 
        license_info = metadata.get('license')
 
        if not license_info:
            issues.append('No license information provided')
            return {
                'validator': 'license',
                'passed': False,
                'issues': issues
            }
 
        # Check license compatibility
        allowed_licenses = [
            'MIT', 'Apache-2.0', 'BSD-3-Clause', 'CC0-1.0',
            'CC-BY-4.0', 'CC-BY-SA-4.0', 'public-domain'
        ]
 
        if license_info not in allowed_licenses:
            issues.append(f"License '{license_info}' may not be compatible")
 
        # Check for license restrictions
        restrictions = metadata.get('license_restrictions', [])
        if 'commercial-use-prohibited' in restrictions:
            issues.append('Commercial use prohibited')
 
        if 'attribution-required' in restrictions:
            # Not a blocker, just a note
            pass
 
        return {
            'validator': 'license',
            'passed': len(issues) == 0,
            'license': license_info,
            'issues': issues
        }
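
A usage sketch, assuming all five validator classes are implemented (only three are shown above); the dataset path and metadata values are placeholders:

# Sketch: gate a candidate fine-tuning dataset before it enters training
validator = DatasetValidator(config={'trusted_sources': ['internal-annotation-team']})

report = validator.validate_dataset(
    'data/customer_support_finetune.jsonl',   # hypothetical path
    metadata={
        'source': 'internal-annotation-team',
        'collection_date': '2024-01-15',
        'collection_method': 'manual annotation',
        'license': 'CC-BY-4.0',
    },
)

if report['overall_status'] != 'passed':
    failed = [c for c in report['checks'] if not c['passed']]
    raise ValueError(f'Dataset rejected: {failed}')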

ML Library Security

Dependency Scanning

from datetime import datetime

class MLDependencyScanner:
    """Scan ML dependencies for security issues."""
 
    critical_packages = [
        'torch', 'tensorflow', 'transformers', 'numpy',
        'scipy', 'pandas', 'scikit-learn', 'keras',
        'onnx', 'onnxruntime', 'langchain', 'openai'
    ]
 
    def scan_environment(self) -> dict:
        """Scan installed packages for vulnerabilities."""
 
        results = {
            'scan_date': datetime.utcnow().isoformat(),
            'packages': [],
            'vulnerabilities': [],
            'recommendations': []
        }
 
        # Get installed packages
        installed = self._get_installed_packages()
 
        for pkg in self.critical_packages:
            if pkg in installed:
                pkg_info = {
                    'name': pkg,
                    'version': installed[pkg],
                    'vulnerabilities': []
                }
 
                # Check for known vulnerabilities
                vulns = self._check_vulnerabilities(pkg, installed[pkg])
                pkg_info['vulnerabilities'] = vulns
 
                if vulns:
                    results['vulnerabilities'].extend([
                        {**v, 'package': pkg} for v in vulns
                    ])
 
                results['packages'].append(pkg_info)
 
        # Generate recommendations
        results['recommendations'] = self._generate_recommendations(results)
 
        return results
 
    def _check_vulnerabilities(self, package: str, version: str) -> list:
        """Check package version for known vulnerabilities."""
 
        # In production, query vulnerability databases
        # (PyPI Advisory Database, OSV, etc.)
 
        known_vulns = {
            'torch': {
                '<1.13.0': [
                    {
                        'id': 'CVE-2022-XXXX',
                        'severity': 'HIGH',
                        'description': 'Arbitrary code execution via pickle'
                    }
                ]
            },
            'transformers': {
                '<4.30.0': [
                    {
                        'id': 'CVE-2023-XXXX',
                        'severity': 'CRITICAL',
                        'description': 'Remote code execution in model loading'
                    }
                ]
            }
        }
 
        vulns = []
        if package in known_vulns:
            for version_constraint, vuln_list in known_vulns[package].items():
                if self._version_matches(version, version_constraint):
                    vulns.extend(vuln_list)
 
        return vulns
 
    def generate_sbom(self) -> dict:
        """Generate Software Bill of Materials for ML stack."""
 
        sbom = {
            'sbom_version': '1.0',
            'created': datetime.utcnow().isoformat(),
            'components': []
        }
 
        installed = self._get_installed_packages()
 
        for pkg, version in installed.items():
            component = {
                'type': 'library',
                'name': pkg,
                'version': version,
                'purl': f'pkg:pypi/{pkg}@{version}',
                'licenses': self._get_package_license(pkg),
                'hashes': self._get_package_hashes(pkg, version)
            }
            sbom['components'].append(component)
 
        return sbom

Secure Requirements Configuration

# requirements-secure.txt example with pinned versions and hashes
 
# Core ML dependencies with hashes
torch==2.1.0 \
    --hash=sha256:abc123... \
    --hash=sha256:def456...
 
transformers==4.35.0 \
    --hash=sha256:789abc...
 
# Use safetensors for safer model loading
safetensors==0.4.0 \
    --hash=sha256:xyz123...
 
# Avoid pickle-based serialization where possible
# Use ONNX or SafeTensors formats
 
# Pin transitive dependencies
numpy==1.24.0 \
    --hash=sha256:numpy_hash...
 
# Vulnerability scanning tools are installed outside requirements:
# pip-audit (pip install pip-audit) for Python package advisories, and the
# standalone Trivy binary for container/filesystem scanning in CI
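
The hashes in a file like this can be generated automatically with pip-compile --generate-hashes (from pip-tools) and enforced at install time with pip install --require-hashes -r requirements-secure.txt, which rejects any package whose hash does not match the pinned value.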

Model Registry Security

Secure Model Registry

import hashlib
import uuid
from datetime import datetime

class SecurityError(Exception):
    """Raised when a model fails a security check."""

class NotFoundError(Exception):
    """Raised when a requested model is not in the registry."""

class IntegrityError(Exception):
    """Raised when checksum or signature verification fails."""

class SecureModelRegistry:
    """Secure registry for approved ML models."""
 
    def __init__(self, storage_backend, signing_key):
        self.storage = storage_backend
        self.signing_key = signing_key
        self.models = {}
 
    def register_model(self, model_info: dict, model_file: bytes,
                      registrant: str) -> str:
        """
        Register a new model with security metadata.
        """
 
        model_id = str(uuid.uuid4())
 
        # Calculate checksums
        sha256_hash = hashlib.sha256(model_file).hexdigest()
        blake2_hash = hashlib.blake2b(model_file).hexdigest()
 
        # Sign the model
        signature = self._sign_model(model_file)
 
        # Scan for vulnerabilities
        scan_result = self._security_scan(model_file)
        if not scan_result['passed']:
            raise SecurityError(f"Model failed security scan: {scan_result['issues']}")
 
        # Create registry entry
        registry_entry = {
            'model_id': model_id,
            'name': model_info['name'],
            'version': model_info['version'],
            'description': model_info.get('description'),
            'source': model_info.get('source'),
            'license': model_info.get('license'),
            'checksums': {
                'sha256': sha256_hash,
                'blake2b': blake2_hash
            },
            'signature': signature.hex(),
            'registered_by': registrant,
            'registered_at': datetime.utcnow().isoformat(),
            'security_scan': scan_result,
            'status': 'active'
        }
 
        # Store model file
        self.storage.store(model_id, model_file)
 
        # Store registry entry
        self.models[model_id] = registry_entry
 
        return model_id
 
    def get_model(self, model_id: str, requester: str) -> dict:
        """
        Retrieve model with verification.
        """
 
        if model_id not in self.models:
            raise NotFoundError(f"Model {model_id} not found")
 
        entry = self.models[model_id]
 
        if entry['status'] != 'active':
            raise SecurityError(f"Model {model_id} is not active")
 
        # Retrieve model file
        model_file = self.storage.retrieve(model_id)
 
        # Verify integrity
        actual_hash = hashlib.sha256(model_file).hexdigest()
        if actual_hash != entry['checksums']['sha256']:
            raise IntegrityError("Model checksum mismatch - possible tampering")
 
        # Verify signature
        signature = bytes.fromhex(entry['signature'])
        if not self._verify_signature(model_file, signature):
            raise IntegrityError("Model signature verification failed")
 
        # Log access
        self._log_access(model_id, requester)
 
        return {
            'model_id': model_id,
            'metadata': entry,
            'model_file': model_file
        }
 
    def revoke_model(self, model_id: str, reason: str, revoker: str):
        """Revoke a model from the registry."""
 
        if model_id not in self.models:
            raise NotFoundError(f"Model {model_id} not found")
 
        self.models[model_id]['status'] = 'revoked'
        self.models[model_id]['revoked_at'] = datetime.utcnow().isoformat()
        self.models[model_id]['revoked_by'] = revoker
        self.models[model_id]['revoke_reason'] = reason
 
        # Alert users of the model
        self._notify_model_users(model_id, reason)

Continuous Monitoring

import time

class AISupplyChainMonitor:
    """Continuous monitoring for AI supply chain security."""
 
    def __init__(self, config: dict):
        self.model_registry = config['model_registry']
        self.alert_system = config['alert_system']
 
    def run_continuous_monitoring(self):
        """Run continuous supply chain monitoring."""
 
        while True:
            # Check for new CVEs affecting dependencies
            new_vulns = self._check_vulnerability_databases()
            if new_vulns:
                self._handle_new_vulnerabilities(new_vulns)
 
            # Verify model integrity
            integrity_issues = self._verify_model_integrity()
            if integrity_issues:
                self._handle_integrity_issues(integrity_issues)
 
            # Check for compromised upstream sources
            source_issues = self._check_upstream_sources()
            if source_issues:
                self._handle_source_issues(source_issues)
 
            # Monitor model behavior for anomalies
            behavior_anomalies = self._monitor_model_behavior()
            if behavior_anomalies:
                self._handle_behavior_anomalies(behavior_anomalies)
 
            time.sleep(3600)  # Check hourly
 
    def _handle_new_vulnerabilities(self, vulns: list):
        """Handle newly discovered vulnerabilities."""
 
        for vuln in vulns:
            # Check if we're affected
            affected_models = self._find_affected_models(vuln)
 
            if affected_models:
                self.alert_system.send_alert({
                    'type': 'vulnerability',
                    'severity': vuln['severity'],
                    'cve': vuln['id'],
                    'affected_models': affected_models,
                    'recommendation': vuln.get('remediation')
                })
 
                # For critical vulnerabilities, consider auto-revocation
                if vuln['severity'] == 'CRITICAL':
                    for model_id in affected_models:
                        self._quarantine_model(model_id, vuln)

Conclusion

AI supply chain security requires vigilance across models, datasets, and dependencies. The complexity of modern ML systems means attack surfaces are broad and constantly evolving.

Key takeaways:

  1. Verify everything - Models, datasets, and dependencies all need validation
  2. Maintain provenance - Document where every component comes from
  3. Scan continuously - New vulnerabilities emerge constantly
  4. Use secure formats - Prefer SafeTensors over pickle
  5. Monitor behavior - Detection may come from runtime anomalies

At DeviDevs, we help organizations secure their AI supply chains with comprehensive security programs. Contact us to discuss your AI security needs.
