AI Security

AI Supply Chain Security: Protecting ML Pipelines

DeviDevs Team
10 min read
#ai-security #supply-chain #mlops #model-security #data-integrity

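AI supply chain security protects the entire machine learning lifecycle, from data collection through model deployment. This guide walks through three concrete controls for ML pipelines: model provenance tracking and signing, dependency vulnerability scanning, and training-data integrity checks.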

Model Provenance Tracking

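Record each model's lineage (the datasets, training configuration, environment, dependencies, and artifact hash) so that any deployed model can be traced back to its inputs and verified: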

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Optional, Any
import hashlib
import json
 
@dataclass
class DatasetInfo:
    """Identity and content fingerprint of a dataset used in a training run.

    ``ProvenanceTracker.register_dataset`` builds these from a data file on
    disk; the per-field notes below describe the values it stores.
    """
    dataset_id: str  # register_dataset: truncated MD5 of "name:version"
    name: str
    version: str
    source: str  # caller-supplied description of where the data came from
    hash: str  # register_dataset: SHA-256 hex digest of the data file
    size_bytes: int
    row_count: int  # register_dataset: -1 for non-CSV files
    created_at: datetime  # registration time (naive UTC in register_dataset)
    schema: Dict  # CSV column -> dtype string from the first 100 rows; {} otherwise
    transformations: List[str]  # caller-supplied preprocessing step descriptions
 
@dataclass
class TrainingConfig:
    """Training settings recorded in a provenance record for reproducibility."""
    framework: str  # name of the ML framework used for training
    framework_version: str
    hyperparameters: Dict  # arbitrary hyperparameters; presumably name -> value
    random_seed: int
    epochs: int
    batch_size: int
    optimizer: str  # optimizer name
    learning_rate: float
 
@dataclass
class ModelArtifact:
    """A serialized model file plus the metadata used to verify it."""
    model_id: str  # ID of the provenance record this artifact belongs to
    name: str
    version: str
    model_hash: str  # register_model_artifact: SHA-256 hex digest of file_path
    file_path: str
    size_bytes: int
    format: str  # serialization format label ('pytorch' by default in register_model_artifact)
    created_at: datetime
    signature: Optional[str] = None  # base64 RSA-PSS signature over model_hash, set by sign_model
 
@dataclass
class ModelProvenance:
    """Lineage record for one training run / model version.

    ``ProvenanceTracker.start_training_run`` creates it with an empty
    configuration and no artifact; later tracker calls fill it in.
    """
    model_id: str
    model_name: str
    version: str
    created_at: datetime
    created_by: str
    datasets: List[DatasetInfo]
    training_config: Optional[TrainingConfig]  # None until set_training_config()
    dependencies: Dict[str, str]  # installed distribution name -> version
    environment: Dict  # platform/host snapshot taken at run start
    metrics: Dict[str, float]
    artifact: Optional[ModelArtifact]  # None until register_model_artifact()
    parent_model: Optional[str] = None  # presumably the model_id this model derives from; not set by the tracker
    attestations: List[Dict] = field(default_factory=list)  # dicts with type/attester/claims/timestamp
 
class ProvenanceTracker:
    """Build, persist, and verify provenance (lineage) records for models.

    Typical flow: ``start_training_run`` -> ``register_dataset`` /
    ``set_training_config`` / ``log_metrics`` -> ``register_model_artifact``
    -> optionally ``sign_model`` / ``add_attestation`` ->
    ``finalize_and_store``. Later, ``verify_model`` re-hashes a model file
    and checks it against the stored record.

    Args:
        storage_backend: Object with ``store(model_id, record)`` and
            ``get(model_id)``; ``get`` must return the stored dict, or a
            falsy value when no record exists.
    """

    def __init__(self, storage_backend):
        self.storage = storage_backend
        # Record for the run in progress; None until start_training_run().
        self.current_provenance: Optional[ModelProvenance] = None

    def start_training_run(
        self,
        model_name: str,
        version: str,
        created_by: str
    ) -> str:
        """Begin tracking a new training run and return its model ID.

        Any previous in-progress record is replaced. The host environment and
        the installed package versions are captured immediately.
        """
        model_id = self._generate_model_id(model_name, version)

        self.current_provenance = ModelProvenance(
            model_id=model_id,
            model_name=model_name,
            version=version,
            created_at=datetime.utcnow(),
            created_by=created_by,
            datasets=[],
            training_config=None,
            dependencies={},
            environment={},
            metrics={},
            artifact=None
        )

        # Capture environment
        self.current_provenance.environment = self._capture_environment()
        self.current_provenance.dependencies = self._capture_dependencies()

        return model_id

    def _generate_model_id(self, name: str, version: str) -> str:
        """Return a 16-hex-character ID derived from name, version and time.

        The timestamp makes each run unique, so re-training the same
        name/version produces a new ID.
        """
        timestamp = datetime.utcnow().isoformat()
        content = f"{name}:{version}:{timestamp}"
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def _capture_environment(self) -> Dict:
        """Snapshot platform, Python version, hostname, CPU count and CUDA availability."""
        import platform
        import os

        return {
            'platform': platform.platform(),
            'python_version': platform.python_version(),
            'hostname': platform.node(),
            'cpu_count': os.cpu_count(),
            'cuda_available': self._check_cuda(),
            'timestamp': datetime.utcnow().isoformat()
        }

    def _check_cuda(self) -> bool:
        """Return True if PyTorch is importable and reports a CUDA device."""
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False

    def _capture_dependencies(self) -> Dict[str, str]:
        """Return ``{distribution_key: version}`` for the current environment.

        Uses ``pkg_resources`` when it is importable. ``pkg_resources`` is
        deprecated and may be absent, so the fallback reads
        ``importlib.metadata`` and normalizes names the way
        ``pkg_resources`` keys do (non-alphanumeric runs -> '-', lower-case).
        """
        try:
            import pkg_resources
        except ImportError:
            import re
            from importlib import metadata

            dependencies = {}
            for dist in metadata.distributions():
                name = dist.metadata.get('Name')
                if name:
                    key = re.sub(r'[^A-Za-z0-9.]+', '-', name).lower()
                    dependencies[key] = dist.version
            return dependencies

        return {pkg.key: pkg.version for pkg in pkg_resources.working_set}

    def register_dataset(
        self,
        name: str,
        version: str,
        source: str,
        data_path: str,
        transformations: Optional[List[str]] = None
    ) -> DatasetInfo:
        """Fingerprint a dataset file and attach it to the current run.

        Args:
            name: Dataset name.
            version: Dataset version label.
            source: Free-form description of where the data came from.
            data_path: Path to the data file; it is hashed, sized and, for
                ``.csv`` files, row-counted and schema-inferred.
            transformations: Preprocessing step descriptions (default: none).

        Returns:
            The new ``DatasetInfo``. It is appended to the current record only
            if a run is active.
        """
        dataset_hash = self._hash_file(data_path)
        file_size = self._get_file_size(data_path)
        row_count = self._count_rows(data_path)
        schema = self._infer_schema(data_path)

        dataset = DatasetInfo(
            dataset_id=self._generate_dataset_id(name, version),
            name=name,
            version=version,
            source=source,
            hash=dataset_hash,
            size_bytes=file_size,
            row_count=row_count,
            created_at=datetime.utcnow(),
            schema=schema,
            transformations=transformations or []
        )

        if self.current_provenance:
            self.current_provenance.datasets.append(dataset)

        return dataset

    def _generate_dataset_id(self, name: str, version: str) -> str:
        """Return a deterministic 12-hex-character ID for ``name:version``.

        MD5 is used only to derive an identifier, not for integrity; content
        integrity is carried by the SHA-256 ``hash`` field.
        """
        return hashlib.md5(f"{name}:{version}".encode()).hexdigest()[:12]

    def _hash_file(self, file_path: str) -> str:
        """Return the SHA-256 hex digest of a file, streamed in 1 MiB chunks."""
        sha256 = hashlib.sha256()
        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                sha256.update(chunk)
        return sha256.hexdigest()

    def _get_file_size(self, file_path: str) -> int:
        """Return the file size in bytes."""
        import os
        return os.path.getsize(file_path)

    def _count_rows(self, file_path: str) -> int:
        """Return the number of data rows in a CSV file, or -1 for other formats.

        The CSV is read in chunks so large datasets are never fully loaded
        into memory just to be counted.
        """
        if not file_path.endswith('.csv'):
            return -1
        import pandas as pd
        with pd.read_csv(file_path, chunksize=100_000) as reader:
            return sum(len(chunk) for chunk in reader)

    def _infer_schema(self, file_path: str) -> Dict:
        """Return ``{column: dtype_str}`` from the first 100 CSV rows, or {} otherwise."""
        if file_path.endswith('.csv'):
            import pandas as pd
            df = pd.read_csv(file_path, nrows=100)
            return {col: str(dtype) for col, dtype in df.dtypes.items()}
        return {}

    def set_training_config(self, config: TrainingConfig):
        """Attach the training configuration to the current run (no-op if none)."""
        if self.current_provenance:
            self.current_provenance.training_config = config

    def log_metrics(self, metrics: Dict[str, float]):
        """Merge training/evaluation metrics into the current run (no-op if none)."""
        if self.current_provenance:
            self.current_provenance.metrics.update(metrics)

    def register_model_artifact(
        self,
        model_path: str,
        model_format: str = 'pytorch'
    ) -> ModelArtifact:
        """Hash the trained model file and attach it to the current run.

        Args:
            model_path: Path to the serialized model.
            model_format: Free-form format label stored with the artifact.

        Returns:
            The new ``ModelArtifact``.

        Raises:
            ValueError: If no run is active (``start_training_run`` was not
                called).
        """
        if not self.current_provenance:
            raise ValueError("No active provenance record")

        provenance = self.current_provenance
        artifact = ModelArtifact(
            model_id=provenance.model_id,
            name=provenance.model_name,
            version=provenance.version,
            model_hash=self._hash_file(model_path),
            file_path=model_path,
            size_bytes=self._get_file_size(model_path),
            format=model_format,
            created_at=datetime.utcnow()
        )
        provenance.artifact = artifact
        return artifact

    def sign_model(self, private_key_path: str, password: Optional[bytes] = None):
        """Sign the current artifact's hash so it can be verified later.

        The hex ``model_hash`` string is signed with RSA-PSS (MGF1/SHA-256,
        maximum salt length), and the base64-encoded signature is stored in
        ``artifact.signature``.

        Args:
            private_key_path: Path to a PEM private key. The PSS padding
                arguments require an RSA key.
            password: Passphrase for an encrypted PEM key; ``None`` for an
                unencrypted key.

        Raises:
            ValueError: If there is no registered artifact to sign.
        """
        from cryptography.hazmat.primitives import hashes, serialization
        from cryptography.hazmat.primitives.asymmetric import padding
        import base64

        if not self.current_provenance or not self.current_provenance.artifact:
            raise ValueError("No model artifact to sign")

        with open(private_key_path, 'rb') as f:
            private_key = serialization.load_pem_private_key(f.read(), password=password)

        model_hash = self.current_provenance.artifact.model_hash.encode()
        signature = private_key.sign(
            model_hash,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )

        self.current_provenance.artifact.signature = base64.b64encode(signature).decode()

    def add_attestation(
        self,
        attestation_type: str,
        attester: str,
        claims: Dict
    ):
        """Append a timestamped attestation to the current run (no-op if none)."""
        attestation = {
            'type': attestation_type,
            'attester': attester,
            'claims': claims,
            'timestamp': datetime.utcnow().isoformat()
        }

        if self.current_provenance:
            self.current_provenance.attestations.append(attestation)

    def finalize_and_store(self) -> Dict:
        """Serialize the current record, store it under its model ID, and return it.

        Raises:
            ValueError: If no run is active.
        """
        if not self.current_provenance:
            raise ValueError("No active provenance record")

        provenance_dict = self._to_dict(self.current_provenance)
        self.storage.store(self.current_provenance.model_id, provenance_dict)

        return provenance_dict

    def _to_dict(self, provenance: ModelProvenance) -> Dict:
        """Convert a provenance record into a plain dict of builtin types.

        Datetimes become ISO-8601 strings. Every recorded field is kept so
        that the stored record is sufficient to audit or reproduce the run.
        """
        config = provenance.training_config
        artifact = provenance.artifact
        return {
            'model_id': provenance.model_id,
            'model_name': provenance.model_name,
            'version': provenance.version,
            'created_at': provenance.created_at.isoformat(),
            'created_by': provenance.created_by,
            'parent_model': provenance.parent_model,
            'datasets': [
                {
                    'dataset_id': d.dataset_id,
                    'name': d.name,
                    'version': d.version,
                    'source': d.source,
                    'hash': d.hash,
                    'size_bytes': d.size_bytes,
                    'row_count': d.row_count,
                    'created_at': d.created_at.isoformat(),
                    'schema': d.schema,
                    'transformations': d.transformations
                }
                for d in provenance.datasets
            ],
            'training_config': {
                'framework': config.framework,
                'framework_version': config.framework_version,
                'hyperparameters': config.hyperparameters,
                'random_seed': config.random_seed,
                'epochs': config.epochs,
                'batch_size': config.batch_size,
                'optimizer': config.optimizer,
                'learning_rate': config.learning_rate
            } if config else None,
            'dependencies': provenance.dependencies,
            'environment': provenance.environment,
            'metrics': provenance.metrics,
            'artifact': {
                'model_hash': artifact.model_hash,
                'signature': artifact.signature,
                'format': artifact.format,
                'file_path': artifact.file_path,
                'size_bytes': artifact.size_bytes,
                'created_at': artifact.created_at.isoformat()
            } if artifact else None,
            'attestations': provenance.attestations
        }

    def verify_model(
        self,
        model_id: str,
        model_path: str,
        public_key_path: str,
        require_signature: bool = False
    ) -> Dict:
        """Verify a model file against its stored provenance record.

        Args:
            model_id: ID of the stored record.
            model_path: Model file to check.
            public_key_path: PEM public key used when the record is signed.
            require_signature: If True, an unsigned record is rejected.
                Otherwise a record without a signature is verified by hash
                alone, which does not protect against an attacker who can
                rewrite the stored record.

        Returns:
            ``{'valid': True, 'model_id', 'verified_at'}`` on success,
            otherwise ``{'valid': False, 'error': ...}`` (plus ``expected`` /
            ``actual`` for a hash mismatch).
        """
        provenance = self.storage.get(model_id)

        if not provenance:
            return {'valid': False, 'error': 'Provenance record not found'}

        artifact = provenance.get('artifact')
        if not artifact or not artifact.get('model_hash'):
            return {'valid': False, 'error': 'Provenance record has no model artifact'}

        current_hash = self._hash_file(model_path)
        expected_hash = artifact['model_hash']

        if current_hash != expected_hash:
            return {
                'valid': False,
                'error': 'Model hash mismatch',
                'expected': expected_hash,
                'actual': current_hash
            }

        signature = artifact.get('signature')
        if signature:
            if not self._verify_signature(current_hash, signature, public_key_path):
                return {'valid': False, 'error': 'Invalid signature'}
        elif require_signature:
            return {'valid': False, 'error': 'Model is not signed'}

        return {
            'valid': True,
            'model_id': model_id,
            'verified_at': datetime.utcnow().isoformat()
        }

    def _verify_signature(self, model_hash: str, signature_b64: str, public_key_path: str) -> bool:
        """Return True if *signature_b64* is a valid signature over *model_hash*.

        Mirrors ``sign_model``: RSA-PSS with MGF1/SHA-256 and maximum salt
        length over the encoded hex hash. A signature mismatch or malformed
        base64 returns False. Other errors, such as an unreadable key file or
        a key type that does not accept PSS padding, propagate so that
        misconfiguration is not reported as tampering.
        """
        from cryptography.exceptions import InvalidSignature
        from cryptography.hazmat.primitives import hashes, serialization
        from cryptography.hazmat.primitives.asymmetric import padding
        import base64

        with open(public_key_path, 'rb') as f:
            public_key = serialization.load_pem_public_key(f.read())

        try:
            signature = base64.b64decode(signature_b64)
            public_key.verify(
                signature,
                model_hash.encode(),
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
        except (InvalidSignature, ValueError):
            # ValueError also covers binascii.Error from malformed base64.
            return False
        return True

ML Dependency Scanner

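Scan ML dependencies for known vulnerabilities. The vulnerability table in this example is a small, hard-coded sample, and one entry uses a placeholder CVE ID. In production, load advisories from a maintained source such as the OSV database or a tool like pip-audit: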

from dataclasses import dataclass
from typing import List, Dict, Optional
import subprocess
import json
 
@dataclass
class VulnerabilityFinding:
    """One known vulnerability matched against a concrete package version."""
    package: str
    installed_version: str  # version found installed, pinned, or embedded in a model
    vulnerable_versions: str  # range spec from the vulnerability table, e.g. '<1.9.0'
    severity: str  # severity label such as 'CRITICAL' or 'HIGH'
    cve_id: Optional[str]
    description: str
    fix_version: Optional[str]  # version the table lists as fixed, if known
 
class MLDependencyScanner:
    """Check package versions against a table of known vulnerabilities.

    Versions can come from the running environment (``scan_environment``),
    a pip requirements file (``scan_requirements_file``), or a
    ``dependencies`` dict saved inside a PyTorch checkpoint
    (``scan_model_dependencies``). The table from
    ``_load_known_vulnerabilities`` is a small hard-coded sample, not a
    complete advisory database.
    """

    # Report sort rank per severity; unknown severities sort after these.
    _SEVERITY_ORDER = {'CRITICAL': 0, 'HIGH': 1, 'MEDIUM': 2, 'LOW': 3}

    def __init__(self):
        # Packages of particular interest for ML workloads. Informational:
        # the scan methods are driven by self.known_vulns, not this list.
        self.ml_packages = [
            'torch', 'tensorflow', 'keras', 'scikit-learn', 'numpy',
            'pandas', 'scipy', 'transformers', 'huggingface-hub',
            'onnx', 'onnxruntime', 'mlflow', 'wandb', 'pytorch-lightning'
        ]
        self.known_vulns = self._load_known_vulnerabilities()

    def _load_known_vulnerabilities(self) -> Dict:
        """Return ``{lowercase_package: [vulnerability dicts]}``.

        Each vulnerability dict has ``versions`` (range spec), ``severity``,
        ``cve``, ``description`` and ``fix_version``. In production, fetch
        these from a vulnerability database instead.
        """
        return {
            'torch': [
                {
                    'versions': '<1.9.0',
                    'severity': 'HIGH',
                    'cve': 'CVE-2022-XXXXX',  # placeholder ID
                    'description': 'Arbitrary code execution via pickle',
                    'fix_version': '1.9.0'
                }
            ],
            'tensorflow': [
                {
                    'versions': '<2.8.4',
                    'severity': 'CRITICAL',
                    'cve': 'CVE-2022-35959',
                    'description': 'CHECK failure in FractionalMaxPoolGrad',
                    'fix_version': '2.8.4'
                }
            ],
            'numpy': [
                {
                    'versions': '<1.22.0',
                    'severity': 'HIGH',
                    'cve': 'CVE-2021-41495',
                    'description': 'NULL pointer dereference in numpy.sort',
                    'fix_version': '1.22.0'
                }
            ],
            'pillow': [
                {
                    'versions': '<9.0.0',
                    'severity': 'HIGH',
                    'cve': 'CVE-2022-22815',
                    'description': 'Improper initialization of ImagePath.Path in path_getbbox',
                    'fix_version': '9.0.0'
                }
            ]
        }

    def _installed_packages(self) -> Dict[str, str]:
        """Return ``{lowercase_distribution_key: version}`` for the environment.

        Prefers ``pkg_resources`` (deprecated, may be absent) and otherwise
        falls back to ``importlib.metadata``, normalizing names the way
        ``pkg_resources`` keys do.
        """
        try:
            import pkg_resources
        except ImportError:
            import re
            from importlib import metadata

            installed = {}
            for dist in metadata.distributions():
                name = dist.metadata.get('Name')
                if name:
                    installed[re.sub(r'[^A-Za-z0-9.]+', '-', name).lower()] = dist.version
            return installed

        return {pkg.key: pkg.version for pkg in pkg_resources.working_set}

    def _match_vulnerabilities(self, package: str, version: str, vulns) -> List[VulnerabilityFinding]:
        """Return a finding for every entry in *vulns* whose range contains *version*."""
        return [
            VulnerabilityFinding(
                package=package,
                installed_version=version,
                vulnerable_versions=vuln['versions'],
                severity=vuln['severity'],
                cve_id=vuln.get('cve'),
                description=vuln['description'],
                fix_version=vuln.get('fix_version')
            )
            for vuln in vulns
            if self._version_in_range(version, vuln['versions'])
        ]

    def scan_environment(self) -> List[VulnerabilityFinding]:
        """Scan the packages installed in the current environment."""
        installed = self._installed_packages()

        findings = []
        for package, vulns in self.known_vulns.items():
            installed_version = installed.get(package.lower())
            if installed_version is None:
                continue
            findings.extend(self._match_vulnerabilities(package, installed_version, vulns))

        return findings

    def _version_in_range(self, version: str, range_spec: str) -> bool:
        """Return True if *version* satisfies every clause of *range_spec*.

        *range_spec* holds one or more comma-separated clauses such as
        ``'<1.9.0'`` or ``'>=2.0,<2.8.4'``. Supported operators are ``<=``,
        ``>=``, ``==``, ``!=``, ``<`` and ``>``; two-character operators are
        tested first so ``'<=1.0'`` is not read as ``'<'`` + ``'=1.0'``.
        Versions are ordered with ``packaging.version``. An unknown operator,
        an empty spec, or an unparseable version yields False rather than
        raising, so one odd entry cannot abort a whole scan.
        """
        import operator
        from packaging import version as pkg_version

        comparators = (
            ('<=', operator.le),
            ('>=', operator.ge),
            ('==', operator.eq),
            ('!=', operator.ne),
            ('<', operator.lt),
            ('>', operator.gt),
        )

        try:
            installed = pkg_version.parse(version)
        except pkg_version.InvalidVersion:
            return False

        clauses = [clause.strip() for clause in range_spec.split(',') if clause.strip()]
        if not clauses:
            return False

        for clause in clauses:
            for symbol, compare in comparators:
                if clause.startswith(symbol):
                    try:
                        bound = pkg_version.parse(clause[len(symbol):].strip())
                    except pkg_version.InvalidVersion:
                        return False
                    if not compare(installed, bound):
                        return False
                    break
            else:
                return False  # no recognised operator

        return True

    @staticmethod
    def _parse_pinned_requirement(line: str) -> Optional[tuple]:
        """Extract ``(package, version)`` from an exact ``pkg==version`` pin.

        Handles inline ``#`` comments, ``;`` environment markers, ``[extras]``
        and trailing pip options such as ``--hash``. Returns None for blank
        lines, comments, pip option lines (``-r``, ``-e``, ...) and anything
        that is not an exact ``==`` pin.
        """
        requirement = line.split('#', 1)[0].split(';', 1)[0].strip()
        if not requirement or requirement.startswith('-'):
            return None

        name, sep, rest = requirement.partition('==')
        if not sep:
            return None

        # '===' (arbitrary equality) leaves a leading '='; pip options such
        # as '--hash=...' follow the version after whitespace.
        tokens = rest.lstrip('=').split()
        if not tokens:
            return None
        version = tokens[0].split(',', 1)[0]

        package = name.split('[', 1)[0].strip()  # "torch[cuda]" -> "torch"
        if not package or not version or any(ch in package for ch in '<>!~=,'):
            return None
        return package, version

    def scan_requirements_file(self, file_path: str) -> List[VulnerabilityFinding]:
        """Scan exact ``==`` pins in a pip requirements file."""
        findings = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                pinned = self._parse_pinned_requirement(line)
                if pinned is None:
                    continue
                package, version = pinned
                findings.extend(self._match_vulnerabilities(
                    package, version, self.known_vulns.get(package.lower(), [])
                ))

        return findings

    def scan_model_dependencies(self, model_path: str) -> List[VulnerabilityFinding]:
        """Scan a ``dependencies`` dict stored in a PyTorch checkpoint.

        Only ``.pt`` / ``.pth`` files are inspected. The checkpoint is loaded
        with ``weights_only=True`` because a full ``torch.load`` unpickles
        arbitrary objects and can execute code from an untrusted file, which
        is exactly the supply-chain attack this scanner guards against. On
        torch versions without ``weights_only``, or on any load failure, a
        message is printed and no findings are returned.
        """
        findings = []

        if not model_path.endswith(('.pt', '.pth')):
            return findings

        try:
            import torch
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
        except Exception as e:
            print(f"Failed to scan model: {e}")
            return findings

        deps = checkpoint.get('dependencies') if isinstance(checkpoint, dict) else None
        if not isinstance(deps, dict):
            return findings

        for package, version in deps.items():
            if isinstance(package, str) and isinstance(version, str):
                findings.extend(self._match_vulnerabilities(
                    package, version, self.known_vulns.get(package.lower(), [])
                ))

        return findings

    def generate_report(self, findings: List[VulnerabilityFinding]) -> Dict:
        """Summarize findings by severity and list them most severe first."""
        severity_counts = {'CRITICAL': 0, 'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
        for finding in findings:
            severity_counts[finding.severity] = severity_counts.get(finding.severity, 0) + 1

        unknown_rank = len(self._SEVERITY_ORDER)
        ordered = sorted(findings, key=lambda x: self._SEVERITY_ORDER.get(x.severity, unknown_rank))

        return {
            'scan_time': datetime.utcnow().isoformat(),
            'summary': {
                'total_vulnerabilities': len(findings),
                'by_severity': severity_counts
            },
            'findings': [
                {
                    'package': f.package,
                    'installed_version': f.installed_version,
                    'vulnerable_versions': f.vulnerable_versions,
                    'severity': f.severity,
                    'cve': f.cve_id,
                    'description': f.description,
                    'fix_version': f.fix_version
                }
                for f in ordered
            ]
        }

Training Data Integrity

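Verify training data before it is used: check its source and checksum, compare its statistics against a baseline to flag possible poisoning, and confirm the label distribution: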

from dataclasses import dataclass
from typing import List, Dict, Optional
import hashlib
import json
 
@dataclass
class DataIntegrityCheck:
    """Outcome of a single training-data integrity check."""
    check_type: str  # e.g. 'source_verification', 'checksum_verification'
    passed: bool
    details: Dict  # check-specific findings (stats, hashes, anomalies, errors)
    timestamp: datetime  # when the check was created (naive UTC in TrainingDataIntegrity)
 
class TrainingDataIntegrity:
    """Run integrity checks on training data and aggregate the results.

    Every check method records its ``DataIntegrityCheck`` in ``self.checks``,
    including checks that fail early, and ``generate_integrity_report``
    summarizes all of them.
    """

    # Registry hosts trusted by verify_data_source. A source is trusted when
    # its parsed host equals one of these or is a subdomain of one
    # (e.g. "my-bucket.s3.amazonaws.com").
    TRUSTED_SOURCE_HOSTS = (
        'huggingface.co',
        'kaggle.com',
        's3.amazonaws.com',
        'storage.googleapis.com',
    )

    # Algorithms accepted by verify_checksum. SHAKE variants are excluded
    # because their hexdigest() requires an explicit length.
    _CHECKSUM_ALGORITHMS = frozenset(
        name for name in hashlib.algorithms_guaranteed if not name.startswith('shake_')
    )

    def __init__(self):
        self.checks: List[DataIntegrityCheck] = []

    def verify_data_source(
        self,
        source_config: Dict,
        trusted_sources: Optional[List[str]] = None
    ) -> DataIntegrityCheck:
        """Check that a data source is specified and hosted on a trusted registry.

        Args:
            source_config: Dict with a ``'url'`` or, failing that, ``'path'``.
            trusted_sources: Host names to trust instead of
                ``TRUSTED_SOURCE_HOSTS``.

        Returns:
            A check that fails only when no source is given. An untrusted
            host sets ``details['source_trusted'] = False`` plus a warning.
            Trust is decided on the parsed host name rather than a substring
            match, so URLs such as ``https://huggingface.co.evil.example`` or
            ``https://evil.example/huggingface.co/`` are not trusted.
        """
        from urllib.parse import urlparse

        check = DataIntegrityCheck(
            check_type='source_verification',
            passed=True,
            details={},
            timestamp=datetime.utcnow()
        )

        source_url = source_config.get('url') or source_config.get('path')
        if not source_url:
            check.passed = False
            check.details['error'] = 'No source URL or path specified'
            self.checks.append(check)  # record failures too, so reports reflect them
            return check

        hosts = self.TRUSTED_SOURCE_HOSTS if trusted_sources is None else trusted_sources
        try:
            # Scheme-less inputs ("huggingface.co/datasets/x") are parsed as a
            # network location so their host can still be recognised.
            parsed = urlparse(source_url if '://' in source_url else '//' + source_url)
            host = (parsed.hostname or '').lower()
        except ValueError:
            host = ''

        trusted = bool(host) and any(
            host == ts.lower() or host.endswith('.' + ts.lower()) for ts in hosts
        )
        check.details['source_trusted'] = trusted
        if not trusted:
            check.details['warning'] = 'Source not from trusted registry'

        self.checks.append(check)
        return check

    def verify_checksum(self, file_path: str, expected_hash: str, algorithm: str = 'sha256') -> DataIntegrityCheck:
        """Compare a file's digest with *expected_hash* (hex, case-insensitive).

        Args:
            file_path: File to hash.
            expected_hash: Expected hex digest; surrounding whitespace and
                letter case are ignored.
            algorithm: Any ``hashlib.algorithms_guaranteed`` name except the
                SHAKE variants (e.g. ``'sha256'``, ``'sha512'``, ``'md5'``).
                MD5 is not collision-resistant and is unsuitable against a
                deliberate attacker.

        Raises:
            ValueError: If *algorithm* is not supported.
        """
        name = algorithm.lower()
        if name not in self._CHECKSUM_ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}")
        hasher = hashlib.new(name)

        with open(file_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                hasher.update(chunk)

        actual_hash = hasher.hexdigest()

        check = DataIntegrityCheck(
            check_type='checksum_verification',
            passed=actual_hash == expected_hash.strip().lower(),
            details={
                'algorithm': algorithm,
                'expected': expected_hash,
                'actual': actual_hash
            },
            timestamp=datetime.utcnow()
        )

        self.checks.append(check)
        return check

    def detect_data_poisoning(
        self,
        dataset,
        baseline_stats: Dict = None,
        outlier_threshold: float = 3.0,
        max_outlier_ratio: float = 0.01
    ) -> DataIntegrityCheck:
        """Flag statistical anomalies that may indicate data poisoning.

        Args:
            dataset: Array-like of numeric values of any shape; statistics
                are computed over all elements.
            baseline_stats: Optional dict with ``'mean'`` and ``'std'`` of
                trusted data. The check fails if the mean moves by more than
                2 baseline std or the std changes by more than 1 baseline std.
            outlier_threshold: |z-score| above which an element is an outlier.
            max_outlier_ratio: Fail when the fraction of outlier elements
                exceeds this value.

        Returns:
            The check. An empty dataset yields a failed check with an
            ``error`` detail instead of raising.
        """
        import numpy as np

        check = DataIntegrityCheck(
            check_type='poisoning_detection',
            passed=True,
            details={'anomalies': []},
            timestamp=datetime.utcnow()
        )

        data = np.asarray(dataset)
        if data.size == 0:
            check.passed = False
            check.details['error'] = 'Dataset is empty'
            self.checks.append(check)
            return check

        current_stats = {
            'mean': float(np.mean(data)),
            'std': float(np.std(data)),
            'min': float(np.min(data)),
            'max': float(np.max(data))
        }
        check.details['current_stats'] = current_stats

        if baseline_stats:
            mean_shift = abs(current_stats['mean'] - baseline_stats['mean'])
            std_change = abs(current_stats['std'] - baseline_stats['std'])

            if mean_shift > baseline_stats['std'] * 2:
                check.details['anomalies'].append({
                    'type': 'mean_shift',
                    'magnitude': mean_shift,
                    'threshold': baseline_stats['std'] * 2
                })
                check.passed = False

            if std_change > baseline_stats['std']:
                check.details['anomalies'].append({
                    'type': 'variance_change',
                    'magnitude': std_change,
                    'threshold': baseline_stats['std']
                })
                check.passed = False

        # Constant data has no outliers; skipping avoids a 0/0 division.
        std = current_stats['std']
        if std > 0:
            z_scores = np.abs((data - current_stats['mean']) / std)
            # Normalize by element count (not len(), which is only the first
            # axis) so the ratio stays within [0, 1] for multi-dimensional data.
            outlier_ratio = float(np.count_nonzero(z_scores > outlier_threshold)) / data.size
        else:
            outlier_ratio = 0.0

        if outlier_ratio > max_outlier_ratio:
            check.details['anomalies'].append({
                'type': 'high_outlier_ratio',
                'ratio': outlier_ratio,
                'threshold': max_outlier_ratio
            })
            check.passed = False

        self.checks.append(check)
        return check

    def verify_label_integrity(
        self,
        labels,
        expected_distribution: Dict = None,
        tolerance: float = 0.1
    ) -> DataIntegrityCheck:
        """Compare the observed label distribution with an expected one.

        Args:
            labels: Iterable of hashable labels.
            expected_distribution: Optional ``{label: expected_fraction}``.
                Labels missing from the data count as 0%. Labels present in
                the data but absent from this mapping are treated as expected
                at 0%, so an unexpected class above *tolerance* is flagged.
            tolerance: Maximum allowed absolute difference between the
                expected and observed fraction (default 0.1, i.e. 10
                percentage points).
        """
        from collections import Counter

        check = DataIntegrityCheck(
            check_type='label_integrity',
            passed=True,
            details={},
            timestamp=datetime.utcnow()
        )

        label_counts = Counter(labels)
        total = sum(label_counts.values())
        distribution = {label: count / total for label, count in label_counts.items()}

        check.details['current_distribution'] = distribution

        if expected_distribution:
            unexpected = [label for label in distribution if label not in expected_distribution]
            for label in list(expected_distribution) + unexpected:
                expected_ratio = expected_distribution.get(label, 0)
                actual_ratio = distribution.get(label, 0)
                if abs(actual_ratio - expected_ratio) > tolerance:
                    check.details.setdefault('deviations', []).append({
                        'label': label,
                        'expected': expected_ratio,
                        'actual': actual_ratio
                    })
                    check.passed = False

        self.checks.append(check)
        return check

    def generate_integrity_report(self) -> Dict:
        """Summarize all recorded checks; overall status is PASS only if all passed."""
        return {
            'report_time': datetime.utcnow().isoformat(),
            'overall_status': 'PASS' if all(c.passed for c in self.checks) else 'FAIL',
            'checks': [
                {
                    'type': c.check_type,
                    'passed': c.passed,
                    'details': c.details,
                    'timestamp': c.timestamp.isoformat()
                }
                for c in self.checks
            ]
        }

Conclusion

AI supply chain security requires protecting every stage of the ML lifecycle. Track provenance to maintain model lineage, scan dependencies for known vulnerabilities, and verify training data integrity. Sign models, and verify those signatures before deployment. AI supply chain attacks can be subtle, so continuous monitoring and regular security assessments are essential for keeping an ML pipeline secure.


Need help securing your ML supply chain? DeviDevs builds secure MLOps platforms with provenance tracking and integrity verification. Get a free assessment →

Weekly AI Security & Automation Digest

Get the latest on AI Security, workflow automation, secure integrations, and custom platform development delivered weekly.

No spam. Unsubscribe anytime.