AI Security

AI Security Testing Automation: Comprehensive ML System Vulnerability Assessment

DeviDevs Team
10 min read
#AI security · #security testing · #adversarial testing · #ML testing · #automation

Automated security testing is essential for identifying vulnerabilities in AI systems before deployment. This guide builds an automated testing framework for ML system security assessment, covering adversarial robustness, input validation, model extraction, privacy, and noise tolerance, and shows how to run it in CI/CD.

AI Security Testing Framework

Core Testing Architecture

import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Callable, Any
from enum import Enum

import numpy as np
 
class TestCategory(Enum):
    ADVERSARIAL = "adversarial"
    INPUT_VALIDATION = "input_validation"
    MODEL_EXTRACTION = "model_extraction"
    DATA_POISONING = "data_poisoning"
    PROMPT_INJECTION = "prompt_injection"
    PRIVACY = "privacy"
    ROBUSTNESS = "robustness"
 
class TestResult(Enum):
    PASSED = "passed"
    FAILED = "failed"
    WARNING = "warning"
    ERROR = "error"
    SKIPPED = "skipped"
 
@dataclass
class SecurityTest:
    id: str
    name: str
    category: TestCategory
    description: str
    severity: str  # critical, high, medium, low
    test_function: Callable
    enabled: bool = True
    timeout_seconds: int = 60
 
@dataclass
class TestExecutionResult:
    test_id: str
    test_name: str
    category: TestCategory
    result: TestResult
    severity: str
    details: Dict[str, Any]
    vulnerabilities_found: List[Dict]
    execution_time_ms: float
    timestamp: datetime = field(default_factory=datetime.utcnow)
 
class AISecurityTestFramework:
    def __init__(self, model, config: Dict = None):
        self.model = model
        self.config = config or {}
        self.tests: Dict[str, SecurityTest] = {}
        self.results: List[TestExecutionResult] = []
 
        self._register_default_tests()
 
    def _register_default_tests(self):
        """Register default security tests."""
        # Adversarial tests
        self.register_test(SecurityTest(
            id="ADV-001",
            name="FGSM Attack Resistance",
            category=TestCategory.ADVERSARIAL,
            description="Test model resistance to Fast Gradient Sign Method attacks",
            severity="high",
            test_function=self._test_fgsm_resistance
        ))
 
        self.register_test(SecurityTest(
            id="ADV-002",
            name="PGD Attack Resistance",
            category=TestCategory.ADVERSARIAL,
            description="Test model resistance to Projected Gradient Descent attacks",
            severity="high",
            test_function=self._test_pgd_resistance
        ))
 
        # Input validation tests
        self.register_test(SecurityTest(
            id="INP-001",
            name="Input Bounds Validation",
            category=TestCategory.INPUT_VALIDATION,
            description="Test input validation against out-of-bounds values",
            severity="medium",
            test_function=self._test_input_bounds
        ))
 
        self.register_test(SecurityTest(
            id="INP-002",
            name="Malformed Input Handling",
            category=TestCategory.INPUT_VALIDATION,
            description="Test handling of malformed or corrupted inputs",
            severity="high",
            test_function=self._test_malformed_input
        ))
 
        # Model extraction tests
        self.register_test(SecurityTest(
            id="EXT-001",
            name="Query Rate Detection",
            category=TestCategory.MODEL_EXTRACTION,
            description="Test detection of high-frequency query patterns",
            severity="high",
            test_function=self._test_extraction_detection
        ))
 
        # Privacy tests
        self.register_test(SecurityTest(
            id="PRV-001",
            name="Membership Inference Resistance",
            category=TestCategory.PRIVACY,
            description="Test resistance to membership inference attacks",
            severity="high",
            test_function=self._test_membership_inference
        ))
 
        # Robustness tests
        self.register_test(SecurityTest(
            id="ROB-001",
            name="Noise Tolerance",
            category=TestCategory.ROBUSTNESS,
            description="Test model behavior under noisy inputs",
            severity="medium",
            test_function=self._test_noise_tolerance
        ))
 
    def register_test(self, test: SecurityTest):
        """Register a security test."""
        self.tests[test.id] = test
 
    def run_all_tests(self, test_data: Any = None) -> Dict:
        """Run all registered security tests."""
        self.results = []
        start_time = datetime.utcnow()
 
        for test_id, test in self.tests.items():
            if not test.enabled:
                self.results.append(TestExecutionResult(
                    test_id=test_id,
                    test_name=test.name,
                    category=test.category,
                    result=TestResult.SKIPPED,
                    severity=test.severity,
                    details={"reason": "Test disabled"},
                    vulnerabilities_found=[],
                    execution_time_ms=0
                ))
                continue
 
            result = self._execute_test(test, test_data)
            self.results.append(result)
 
        return self._generate_report(start_time)
 
    def run_tests_by_category(self, category: TestCategory, test_data: Any = None) -> Dict:
        """Run tests for a specific category."""
        self.results = []
        start_time = datetime.utcnow()
 
        for test_id, test in self.tests.items():
            if test.category == category and test.enabled:
                result = self._execute_test(test, test_data)
                self.results.append(result)
 
        return self._generate_report(start_time)
 
    def _execute_test(self, test: SecurityTest, test_data: Any) -> TestExecutionResult:
        """Execute a single security test."""
        start = time.time()
 
        try:
            result = test.test_function(self.model, test_data, self.config)
            execution_time = (time.time() - start) * 1000
 
            return TestExecutionResult(
                test_id=test.id,
                test_name=test.name,
                category=test.category,
                result=result["status"],
                severity=test.severity,
                details=result.get("details", {}),
                vulnerabilities_found=result.get("vulnerabilities", []),
                execution_time_ms=execution_time
            )
        except Exception as e:
            execution_time = (time.time() - start) * 1000
            return TestExecutionResult(
                test_id=test.id,
                test_name=test.name,
                category=test.category,
                result=TestResult.ERROR,
                severity=test.severity,
                details={"error": str(e)},
                vulnerabilities_found=[],
                execution_time_ms=execution_time
            )
 
    def _test_fgsm_resistance(self, model, test_data, config) -> Dict:
        """Test FGSM attack resistance."""
        epsilon_values = config.get("fgsm_epsilons", [0.01, 0.05, 0.1, 0.2])
        success_threshold = config.get("adversarial_success_threshold", 0.3)
 
        vulnerabilities = []
        details = {"epsilon_results": {}}
 
        for epsilon in epsilon_values:
            # Generate FGSM adversarial examples
            attack_success_rate = self._run_fgsm_attack(model, test_data, epsilon)
            details["epsilon_results"][epsilon] = attack_success_rate
 
            if attack_success_rate > success_threshold:
                vulnerabilities.append({
                    "type": "FGSM vulnerability",
                    "epsilon": epsilon,
                    "success_rate": attack_success_rate,
                    "severity": "high" if epsilon < 0.1 else "medium"
                })
 
        return {
            "status": TestResult.FAILED if vulnerabilities else TestResult.PASSED,
            "details": details,
            "vulnerabilities": vulnerabilities
        }
 
    def _run_fgsm_attack(self, model, data, epsilon) -> float:
        """Run FGSM attack and return success rate."""
        # Simplified FGSM simulation.
        # In production, this would perturb inputs using actual gradient computation.
        total = len(data) if hasattr(data, '__len__') else 100
        sample_size = max(1, min(total, 100))  # guard against empty datasets
        success_count = 0

        for _ in range(sample_size):
            # Simulate attack success based on epsilon
            # (higher epsilon = higher success rate, simplified)
            if np.random.random() < epsilon * 2:
                success_count += 1

        return success_count / sample_size
 
    def _test_pgd_resistance(self, model, test_data, config) -> Dict:
        """Test PGD attack resistance."""
        iterations = config.get("pgd_iterations", [10, 20, 40])
        epsilon = config.get("pgd_epsilon", 0.1)
        success_threshold = config.get("adversarial_success_threshold", 0.3)
 
        vulnerabilities = []
        details = {"iteration_results": {}}
 
        for num_iterations in iterations:
            attack_success_rate = self._run_pgd_attack(model, test_data, epsilon, num_iterations)
            details["iteration_results"][num_iterations] = attack_success_rate
 
            if attack_success_rate > success_threshold:
                vulnerabilities.append({
                    "type": "PGD vulnerability",
                    "iterations": num_iterations,
                    "epsilon": epsilon,
                    "success_rate": attack_success_rate,
                    "severity": "high"
                })
 
        return {
            "status": TestResult.FAILED if vulnerabilities else TestResult.PASSED,
            "details": details,
            "vulnerabilities": vulnerabilities
        }
 
    def _run_pgd_attack(self, model, data, epsilon, iterations) -> float:
        """Run PGD attack and return success rate."""
        # Simplified PGD implementation
        success_count = 0
        total = 100
 
        for i in range(total):
            # Simulate attack - more iterations = higher success
            success_prob = min(0.8, epsilon * iterations / 20)
            if np.random.random() < success_prob:
                success_count += 1
 
        return success_count / total
 
    def _test_input_bounds(self, model, test_data, config) -> Dict:
        """Test input bounds validation."""
        test_cases = [
            {"name": "negative_values", "input": np.array([-999999])},
            {"name": "very_large_values", "input": np.array([999999999])},
            {"name": "nan_values", "input": np.array([np.nan])},
            {"name": "inf_values", "input": np.array([np.inf])},
            {"name": "empty_input", "input": np.array([])},
        ]
 
        vulnerabilities = []
        details = {"test_cases": {}}
 
        for test_case in test_cases:
            try:
                # Try to get prediction with invalid input
                result = self._safe_predict(model, test_case["input"])
                handled = result.get("handled_gracefully", False)
                details["test_cases"][test_case["name"]] = {
                    "handled": handled,
                    "error": result.get("error")
                }
 
                if not handled:
                    vulnerabilities.append({
                        "type": "Input validation failure",
                        "test_case": test_case["name"],
                        "severity": "medium"
                    })
            except Exception as e:
                vulnerabilities.append({
                    "type": "Unhandled exception on invalid input",
                    "test_case": test_case["name"],
                    "error": str(e),
                    "severity": "high"
                })
 
        return {
            "status": TestResult.FAILED if vulnerabilities else TestResult.PASSED,
            "details": details,
            "vulnerabilities": vulnerabilities
        }
 
    def _safe_predict(self, model, input_data) -> Dict:
        """Safely attempt prediction."""
        try:
            # Would call actual model
            return {"handled_gracefully": True, "result": None}
        except ValueError:
            return {"handled_gracefully": True, "error": "ValueError handled"}
        except Exception as e:
            return {"handled_gracefully": False, "error": str(e)}
 
    def _test_malformed_input(self, model, test_data, config) -> Dict:
        """Test malformed input handling."""
        test_cases = [
            {"name": "wrong_shape", "description": "Input with incorrect dimensions"},
            {"name": "wrong_dtype", "description": "Input with wrong data type"},
            {"name": "unicode_injection", "description": "Unicode characters in input"},
            {"name": "json_injection", "description": "Malicious JSON in input"},
        ]
 
        vulnerabilities = []
        for test_case in test_cases:
            # Simulate malformed input testing
            if np.random.random() < 0.1:  # 10% failure rate for demo
                vulnerabilities.append({
                    "type": "Malformed input vulnerability",
                    "test_case": test_case["name"],
                    "severity": "high"
                })
 
        return {
            "status": TestResult.FAILED if vulnerabilities else TestResult.PASSED,
            "details": {"tests_run": len(test_cases)},
            "vulnerabilities": vulnerabilities
        }
 
    def _test_extraction_detection(self, model, test_data, config) -> Dict:
        """Test model extraction detection capabilities."""
        query_patterns = [
            {"name": "high_frequency", "queries_per_second": 100},
            {"name": "systematic_grid", "pattern": "grid_search"},
            {"name": "boundary_probing", "pattern": "decision_boundary"},
        ]
 
        vulnerabilities = []
        detection_results = {}
 
        for pattern in query_patterns:
            # Simulate detection testing
            detected = np.random.random() > 0.3  # 70% detection rate
            detection_results[pattern["name"]] = detected
 
            if not detected:
                vulnerabilities.append({
                    "type": "Extraction pattern undetected",
                    "pattern": pattern["name"],
                    "severity": "high"
                })
 
        return {
            "status": TestResult.FAILED if vulnerabilities else TestResult.PASSED,
            "details": {"detection_results": detection_results},
            "vulnerabilities": vulnerabilities
        }
 
    def _test_membership_inference(self, model, test_data, config) -> Dict:
        """Test membership inference attack resistance."""
        threshold = config.get("membership_inference_threshold", 0.6)
 
        # Simulate membership inference attack
        attack_accuracy = 0.5 + np.random.random() * 0.3  # 50-80%
 
        vulnerabilities = []
        if attack_accuracy > threshold:
            vulnerabilities.append({
                "type": "Membership inference vulnerability",
                "attack_accuracy": attack_accuracy,
                "threshold": threshold,
                "severity": "high"
            })
 
        return {
            "status": TestResult.FAILED if vulnerabilities else TestResult.PASSED,
            "details": {
                "attack_accuracy": attack_accuracy,
                "threshold": threshold,
                "baseline": 0.5
            },
            "vulnerabilities": vulnerabilities
        }
 
    def _test_noise_tolerance(self, model, test_data, config) -> Dict:
        """Test model noise tolerance."""
        noise_levels = config.get("noise_levels", [0.01, 0.05, 0.1, 0.2])
        degradation_threshold = config.get("degradation_threshold", 0.1)
 
        vulnerabilities = []
        noise_results = {}
 
        for noise_level in noise_levels:
            # Simulate accuracy degradation under noise
            accuracy_drop = noise_level * (1 + np.random.random())
            noise_results[noise_level] = accuracy_drop
 
            if accuracy_drop > degradation_threshold:
                vulnerabilities.append({
                    "type": "Excessive noise sensitivity",
                    "noise_level": noise_level,
                    "accuracy_drop": accuracy_drop,
                    "severity": "medium" if noise_level > 0.1 else "high"
                })
 
        return {
            "status": TestResult.FAILED if vulnerabilities else TestResult.PASSED,
            "details": {"noise_results": noise_results},
            "vulnerabilities": vulnerabilities
        }
 
    def _generate_report(self, start_time: datetime) -> Dict:
        """Generate comprehensive test report."""
        end_time = datetime.utcnow()
 
        passed = sum(1 for r in self.results if r.result == TestResult.PASSED)
        failed = sum(1 for r in self.results if r.result == TestResult.FAILED)
        warnings = sum(1 for r in self.results if r.result == TestResult.WARNING)
        errors = sum(1 for r in self.results if r.result == TestResult.ERROR)
        skipped = sum(1 for r in self.results if r.result == TestResult.SKIPPED)
 
        all_vulnerabilities = []
        for result in self.results:
            all_vulnerabilities.extend(result.vulnerabilities_found)
 
        # Group vulnerabilities by severity
        vuln_by_severity = {
            "critical": [],
            "high": [],
            "medium": [],
            "low": []
        }
        for vuln in all_vulnerabilities:
            severity = vuln.get("severity", "medium")
            # setdefault guards against severities outside the predefined buckets
            vuln_by_severity.setdefault(severity, []).append(vuln)
 
        return {
            "report_id": f"security_test_{start_time.strftime('%Y%m%d_%H%M%S')}",
            "execution_time": {
                "start": start_time.isoformat(),
                "end": end_time.isoformat(),
                "duration_seconds": (end_time - start_time).total_seconds()
            },
            "summary": {
                "total_tests": len(self.results),
                "passed": passed,
                "failed": failed,
                "warnings": warnings,
                "errors": errors,
                "skipped": skipped,
                "pass_rate": passed / len(self.results) if self.results else 0
            },
            "vulnerabilities": {
                "total": len(all_vulnerabilities),
                "by_severity": {k: len(v) for k, v in vuln_by_severity.items()},
                "details": all_vulnerabilities
            },
            "results_by_category": self._group_results_by_category(),
            "detailed_results": [
                {
                    "test_id": r.test_id,
                    "test_name": r.test_name,
                    "category": r.category.value,
                    "result": r.result.value,
                    "severity": r.severity,
                    "execution_time_ms": r.execution_time_ms,
                    "vulnerabilities": r.vulnerabilities_found
                }
                for r in self.results
            ],
            "recommendations": self._generate_recommendations(all_vulnerabilities)
        }
 
    def _group_results_by_category(self) -> Dict:
        """Group results by test category."""
        grouped = {}
        for result in self.results:
            category = result.category.value
            if category not in grouped:
                grouped[category] = {
                    "total": 0,
                    "passed": 0,
                    "failed": 0
                }
            grouped[category]["total"] += 1
            if result.result == TestResult.PASSED:
                grouped[category]["passed"] += 1
            elif result.result == TestResult.FAILED:
                grouped[category]["failed"] += 1
 
        return grouped
 
    def _generate_recommendations(self, vulnerabilities: List[Dict]) -> List[Dict]:
        """Generate security recommendations based on findings."""
        recommendations = []
 
        vuln_types = set(v.get("type", "") for v in vulnerabilities)
 
        if any("FGSM" in t or "PGD" in t for t in vuln_types):
            recommendations.append({
                "priority": "high",
                "area": "Adversarial Robustness",
                "recommendation": "Implement adversarial training to improve model robustness",
                "references": ["https://arxiv.org/abs/1412.6572"]
            })
 
        if any("Input validation" in t for t in vuln_types):
            recommendations.append({
                "priority": "high",
                "area": "Input Validation",
                "recommendation": "Add comprehensive input validation and sanitization",
                "references": []
            })
 
        if any("Membership inference" in t for t in vuln_types):
            recommendations.append({
                "priority": "high",
                "area": "Privacy",
                "recommendation": "Consider differential privacy or model regularization",
                "references": []
            })
 
        if any("Extraction" in t for t in vuln_types):
            recommendations.append({
                "priority": "medium",
                "area": "Model Protection",
                "recommendation": "Implement rate limiting and query monitoring",
                "references": []
            })
 
        return recommendations
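
With all tests registered, running the framework and exporting its report takes only a few lines. Here is a minimal usage sketch; my_model and X_test are placeholders for your own model object and evaluation data:

import json

# my_model and X_test are placeholders for your own model and evaluation data.
framework = AISecurityTestFramework(
    model=my_model,
    config={
        "fgsm_epsilons": [0.01, 0.05, 0.1],
        "adversarial_success_threshold": 0.3,
        "membership_inference_threshold": 0.6,
    }
)

# Run the full suite, or narrow to a single category.
report = framework.run_all_tests(test_data=X_test)
adversarial_report = framework.run_tests_by_category(TestCategory.ADVERSARIAL, test_data=X_test)

print(f"Pass rate: {report['summary']['pass_rate']:.0%}, "
      f"vulnerabilities: {report['vulnerabilities']['total']}")

# default=str keeps the dump robust if custom tests add non-JSON types.
with open("security-report.json", "w") as f:
    json.dump(report, f, indent=2, default=str)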

CI/CD Integration

# .github/workflows/ai-security-tests.yml
name: AI Security Tests
 
on:
  push:
    paths:
      - 'models/**'
      - 'src/ml/**'
  pull_request:
    paths:
      - 'models/**'
      - 'src/ml/**'
  schedule:
    - cron: '0 2 * * 1'  # Weekly Monday 2 AM
 
jobs:
  security-tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
 
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
 
      - name: Install dependencies
        run: |
          pip install -r requirements-ml.txt
          pip install adversarial-robustness-toolbox
          pip install pytest pytest-json-report
 
      - name: Run AI Security Tests
        id: security_tests
        run: |
          python -m pytest tests/security/ \
            --junitxml=security-results.xml \
            --json-report \
            --json-report-file=security-report.json
 
      - name: Parse Results
        run: |
          python scripts/parse_security_results.py security-report.json
 
      - name: Upload Results
        uses: actions/upload-artifact@v4
        with:
          name: security-test-results
          path: |
            security-results.xml
            security-report.json
 
      - name: Comment on PR
        if: github.event_name == 'pull_request'
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            const report = JSON.parse(fs.readFileSync('security-report.json'));
 
            const body = `## 🔒 AI Security Test Results
 
            | Metric | Value |
            |--------|-------|
            | Tests Run | ${report.summary.total_tests} |
            | Passed | ${report.summary.passed} |
            | Failed | ${report.summary.failed} |
            | Vulnerabilities | ${report.vulnerabilities.total} |
 
            ### Vulnerability Summary
            - Critical: ${report.vulnerabilities.by_severity.critical || 0}
            - High: ${report.vulnerabilities.by_severity.high || 0}
            - Medium: ${report.vulnerabilities.by_severity.medium || 0}
            - Low: ${report.vulnerabilities.by_severity.low || 0}
 
            ${report.recommendations.length > 0 ? `
            ### Recommendations
            ${report.recommendations.map(r => `- **${r.area}**: ${r.recommendation}`).join('\n')}
            ` : ''}
            `;
 
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: body
            });
 
      - name: Fail on Critical Vulnerabilities
        run: |
          python -c "
          import json
          with open('security-report.json') as f:
              report = json.load(f)
          critical = report['vulnerabilities']['by_severity'].get('critical', 0)
          if critical > 0:
              print(f'Found {critical} critical vulnerabilities')
              exit(1)
          "

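The workflow above calls a scripts/parse_security_results.py helper that is not shown. A minimal sketch of what it might look like, assuming security-report.json follows the schema produced by _generate_report above (pytest-json-report's native output would need a small adapter first):

#!/usr/bin/env python3
# scripts/parse_security_results.py -- hypothetical helper that prints a
# human-readable summary of the security report to the CI log.
import json
import sys


def main(path: str) -> None:
    with open(path) as f:
        report = json.load(f)

    summary = report["summary"]
    vulns = report["vulnerabilities"]

    print(f"Report {report['report_id']}")
    print(f"  Tests: {summary['total_tests']} "
          f"(passed={summary['passed']}, failed={summary['failed']}, "
          f"errors={summary['errors']}, skipped={summary['skipped']})")
    print(f"  Vulnerabilities: {vulns['total']}")
    for severity, count in vulns["by_severity"].items():
        print(f"    {severity}: {count}")
    for rec in report.get("recommendations", []):
        print(f"  [{rec['priority']}] {rec['area']}: {rec['recommendation']}")
    # Gating on critical findings is handled by the dedicated workflow step above.


if __name__ == "__main__":
    main(sys.argv[1])
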
Best Practices

Testing Strategy

  1. Regular testing: Run security tests on every model update
  2. Comprehensive coverage: Test adversarial, privacy, and robustness aspects
  3. Threshold tuning: Adjust pass/fail thresholds based on risk tolerance (all thresholds are config-driven; see the sketch after this list)
  4. Continuous monitoring: Monitor for new attack vectors
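
All of the pass/fail thresholds used by the tests above are read from the framework's config dict, so tuning (point 3) is a matter of adjusting a handful of keys. The values below are the defaults baked into the test implementations:

# Threshold configuration consumed by the test functions above.
security_config = {
    "fgsm_epsilons": [0.01, 0.05, 0.1, 0.2],
    "pgd_iterations": [10, 20, 40],
    "pgd_epsilon": 0.1,
    "adversarial_success_threshold": 0.3,   # max tolerated adversarial attack success rate
    "membership_inference_threshold": 0.6,  # max tolerated attack accuracy (baseline is 0.5)
    "noise_levels": [0.01, 0.05, 0.1, 0.2],
    "degradation_threshold": 0.1,           # max tolerated accuracy drop under noise
}

# my_model is a placeholder for your own model object.
framework = AISecurityTestFramework(model=my_model, config=security_config)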

Integration Guidelines

  • Include security tests in CI/CD pipelines
  • Block deployments with critical vulnerabilities
  • Generate reports for security team review
  • Track vulnerability trends over time (see the sketch below for one lightweight approach)
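
For trend tracking, one lightweight approach is to append a one-line summary of every report to a history file and plot it over time. A sketch, assuming the report schema from _generate_report (the history/security-trends.jsonl path is an arbitrary choice):

import json
from pathlib import Path


def append_trend_entry(report: dict, history_path: str = "history/security-trends.jsonl") -> None:
    """Append a one-line summary of a security report for later trend analysis."""
    path = Path(history_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    entry = {
        "report_id": report["report_id"],
        "timestamp": report["execution_time"]["end"],
        "pass_rate": report["summary"]["pass_rate"],
        "vulnerabilities_by_severity": report["vulnerabilities"]["by_severity"],
    }
    with path.open("a") as f:
        f.write(json.dumps(entry) + "\n")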

Automated AI security testing helps surface vulnerabilities early and consistently across the ML development lifecycle.
