DevSecOps

DevSecOps Threat Modeling: Integrating Security Design into CI/CD Pipelines

DeviDevs Team
15 min read
#threat modeling #DevSecOps #STRIDE #security design #CI/CD

DevSecOps Threat Modeling: Integrating Security Design into CI/CD Pipelines

Threat modeling identifies potential security threats early in development. This guide shows how to integrate threat modeling into DevSecOps pipelines for continuous security validation.

STRIDE Threat Modeling Framework

Automated STRIDE Analysis

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Optional
import json
 
class ThreatCategory(Enum):
    """The six STRIDE threat categories.

    Each member's value is the human-readable category name; reports use that
    value (``threat.category.value``) as the key of ``threats_by_category``.
    """
    SPOOFING = "Spoofing"  # impersonating a user or system (authentication)
    TAMPERING = "Tampering"  # unauthorized modification of data (integrity)
    REPUDIATION = "Repudiation"  # denying an action without proof otherwise (non-repudiation)
    INFORMATION_DISCLOSURE = "Information Disclosure"  # exposing data (confidentiality)
    DENIAL_OF_SERVICE = "Denial of Service"  # degrading or blocking service (availability)
    ELEVATION_OF_PRIVILEGE = "Elevation of Privilege"  # gaining unearned rights (authorization)
 
@dataclass
class Component:
    """A node in the system model: a user, web app, API, database or external service.

    Raises:
        ValueError: if ``trust_level`` lies outside the documented 0-10 range.
    """
    id: str
    name: str
    type: str  # web_app, api, database, external_service, user
    trust_level: int  # 0-10, higher = more trusted
    data_sensitivity: str  # public, internal, confidential, restricted
    authentication_required: bool = True
    connections: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Risk scoring scales by trust level and assumes the 0-10 range;
        # reject bad input here instead of producing nonsensical scores later.
        if not 0 <= self.trust_level <= 10:
            raise ValueError(
                f"trust_level must be between 0 and 10, got {self.trust_level!r}"
            )
 
@dataclass
class DataFlow:
    """A directed movement of data from one component to another.

    ``source`` and ``destination`` hold component ids (not Component objects).
    """

    id: str
    source: str  # id of the sending component
    destination: str  # id of the receiving component
    protocol: str
    data_type: str
    encrypted: bool = field(default=False)
    authenticated: bool = field(default=False)
 
@dataclass
class Threat:
    """One finding produced by STRIDE analysis.

    ``component_id`` is the id of the affected component, or the id of a data
    flow for flow-level findings. New threats start with status "identified".
    """

    id: str
    category: ThreatCategory
    component_id: str
    description: str
    risk_score: float  # reports bucket >= 8.0 as high and >= 5.0 as medium
    mitigations: List[str]
    status: str = field(default="identified")
 
class STRIDEAnalyzer:
    """Rule-based STRIDE analyzer over a set of components and data flows.

    Register components with ``add_component`` and flows with
    ``add_data_flow``, call ``analyze`` to populate ``self.threats``, then
    ``generate_report`` to summarise the most recent analysis.
    """

    # Sensitivity labels that trigger disclosure / repudiation / unencrypted-flow rules.
    _SENSITIVE_LEVELS = ("confidential", "restricted")

    def __init__(self):
        self.components: Dict[str, Component] = {}
        self.data_flows: List[DataFlow] = []
        self.threats: List[Threat] = []
        self.threat_counter = 0

    def add_component(self, component: Component):
        """Register *component*, replacing any existing one with the same id."""
        self.components[component.id] = component

    def add_data_flow(self, flow: DataFlow):
        """Register a data flow between two component ids."""
        self.data_flows.append(flow)

    def analyze(self) -> List[Threat]:
        """Run complete STRIDE analysis.

        Both the threat list and the ID counter are reset first. Re-running on
        the same components and flows (added in the same order) therefore
        yields the same THR-xxxx IDs.

        Returns:
            The identified threats, highest ``risk_score`` first.
            ``self.threats`` keeps them in discovery order.
        """
        self.threats = []
        # Restart numbering so IDs are stable across repeated runs.
        self.threat_counter = 0

        for component in self.components.values():
            self._analyze_spoofing(component)
            self._analyze_tampering(component)
            self._analyze_repudiation(component)
            self._analyze_information_disclosure(component)
            self._analyze_denial_of_service(component)
            self._analyze_elevation_of_privilege(component)

        for flow in self.data_flows:
            self._analyze_data_flow_threats(flow)

        return sorted(self.threats, key=lambda t: t.risk_score, reverse=True)

    def _generate_threat_id(self) -> str:
        """Return the next sequential threat id (THR-0001, THR-0002, ...)."""
        self.threat_counter += 1
        return f"THR-{self.threat_counter:04d}"

    def _analyze_spoofing(self, component: Component):
        """Analyze spoofing threats (users, unauthenticated APIs, external services)."""
        if component.type == "user":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.SPOOFING,
                component_id=component.id,
                description=f"Attacker could impersonate legitimate user '{component.name}'",
                risk_score=self._calculate_risk(component, 8),
                mitigations=[
                    "Implement multi-factor authentication",
                    "Use strong password policies",
                    "Implement account lockout mechanisms",
                    "Monitor for suspicious login patterns"
                ]
            ))

        if component.type == "api" and not component.authentication_required:
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.SPOOFING,
                component_id=component.id,
                description=f"API '{component.name}' lacks authentication, allowing request spoofing",
                risk_score=self._calculate_risk(component, 9),
                mitigations=[
                    "Implement API authentication (OAuth2, API keys)",
                    "Use mutual TLS for service-to-service communication",
                    "Validate request signatures"
                ]
            ))

        if component.type == "external_service":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.SPOOFING,
                component_id=component.id,
                description=f"External service '{component.name}' could be spoofed by attacker",
                risk_score=self._calculate_risk(component, 7),
                mitigations=[
                    "Verify SSL/TLS certificates",
                    "Use certificate pinning",
                    "Implement webhook signature validation"
                ]
            ))

    def _analyze_tampering(self, component: Component):
        """Analyze tampering threats (databases and APIs)."""
        if component.type == "database":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.TAMPERING,
                component_id=component.id,
                description=f"Data in '{component.name}' could be modified without authorization",
                risk_score=self._calculate_risk(component, 9),
                mitigations=[
                    "Implement row-level security",
                    "Use database audit logging",
                    "Encrypt sensitive data at rest",
                    "Implement integrity checks (checksums, signatures)"
                ]
            ))

        if component.type == "api":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.TAMPERING,
                component_id=component.id,
                description=f"API requests to '{component.name}' could be tampered in transit",
                risk_score=self._calculate_risk(component, 7),
                mitigations=[
                    "Use HTTPS/TLS for all communications",
                    "Implement request signing",
                    "Validate input data thoroughly",
                    "Use HMAC for message integrity"
                ]
            ))

    def _analyze_repudiation(self, component: Component):
        """Analyze repudiation threats (low-trust actors, sensitive data access)."""
        if component.trust_level < 5:
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.REPUDIATION,
                component_id=component.id,
                description=f"Actions performed by '{component.name}' may not be properly logged",
                risk_score=self._calculate_risk(component, 6),
                mitigations=[
                    "Implement comprehensive audit logging",
                    "Use tamper-evident logging (append-only)",
                    "Include timestamps and user identifiers",
                    "Store logs in centralized, secure location"
                ]
            ))

        if component.data_sensitivity in self._SENSITIVE_LEVELS:
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.REPUDIATION,
                component_id=component.id,
                description=f"Access to sensitive data in '{component.name}' may be denied by users",
                risk_score=self._calculate_risk(component, 7),
                mitigations=[
                    "Implement digital signatures for critical operations",
                    "Use non-repudiable transaction logs",
                    "Require acknowledgment for sensitive actions"
                ]
            ))

    def _analyze_information_disclosure(self, component: Component):
        """Analyze information disclosure threats (sensitive data, API responses)."""
        if component.data_sensitivity in self._SENSITIVE_LEVELS:
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.INFORMATION_DISCLOSURE,
                component_id=component.id,
                description=f"Sensitive data in '{component.name}' could be exposed",
                risk_score=self._calculate_risk(component, 9),
                mitigations=[
                    "Encrypt data at rest and in transit",
                    "Implement proper access controls",
                    "Use data masking for non-production environments",
                    "Implement DLP (Data Loss Prevention) controls"
                ]
            ))

        if component.type == "api":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.INFORMATION_DISCLOSURE,
                component_id=component.id,
                description=f"API '{component.name}' could leak sensitive data in responses or errors",
                risk_score=self._calculate_risk(component, 7),
                mitigations=[
                    "Implement response filtering",
                    "Use generic error messages",
                    "Remove debug information in production",
                    "Validate authorization before data access"
                ]
            ))

    def _analyze_denial_of_service(self, component: Component):
        """Analyze denial of service threats (request floods, expensive queries)."""
        if component.type in ["api", "web_app"]:
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.DENIAL_OF_SERVICE,
                component_id=component.id,
                description=f"'{component.name}' could be overwhelmed by malicious requests",
                risk_score=self._calculate_risk(component, 7),
                mitigations=[
                    "Implement rate limiting",
                    "Use DDoS protection services",
                    "Implement request queuing",
                    "Set resource limits and timeouts",
                    "Use auto-scaling infrastructure"
                ]
            ))

        if component.type == "database":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.DENIAL_OF_SERVICE,
                component_id=component.id,
                description=f"Database '{component.name}' could be exhausted by expensive queries",
                risk_score=self._calculate_risk(component, 8),
                mitigations=[
                    "Implement query timeouts",
                    "Use connection pooling",
                    "Limit query complexity",
                    "Implement caching layers"
                ]
            ))

    def _analyze_elevation_of_privilege(self, component: Component):
        """Analyze elevation of privilege threats (APIs and web apps)."""
        if component.type == "api":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.ELEVATION_OF_PRIVILEGE,
                component_id=component.id,
                description=f"Attacker could gain elevated access through '{component.name}'",
                risk_score=self._calculate_risk(component, 9),
                mitigations=[
                    "Implement principle of least privilege",
                    "Use role-based access control (RBAC)",
                    "Validate permissions on every request",
                    "Implement proper session management"
                ]
            ))

        if component.type == "web_app":
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.ELEVATION_OF_PRIVILEGE,
                component_id=component.id,
                description=f"Client-side attacks could elevate privileges in '{component.name}'",
                risk_score=self._calculate_risk(component, 8),
                mitigations=[
                    "Implement server-side authorization checks",
                    "Use Content Security Policy (CSP)",
                    "Sanitize all user inputs",
                    "Implement proper CORS policies"
                ]
            ))

    def _analyze_data_flow_threats(self, flow: DataFlow):
        """Analyze threats specific to a single data flow.

        Flows whose source or destination id is not a registered component
        are skipped. Flow threats use fixed scores (9.0 / 7.5), not
        ``_calculate_risk``, and record the flow id as ``component_id``.
        """
        source = self.components.get(flow.source)
        dest = self.components.get(flow.destination)

        if source is None or dest is None:
            return

        # Sensitive data can travel in either direction: a read out of a
        # restricted store is as exposed as a write into one.
        carries_sensitive_data = (
            source.data_sensitivity in self._SENSITIVE_LEVELS
            or dest.data_sensitivity in self._SENSITIVE_LEVELS
        )
        if not flow.encrypted and carries_sensitive_data:
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.INFORMATION_DISCLOSURE,
                component_id=flow.id,
                description=f"Sensitive data flows unencrypted from {source.name} to {dest.name}",
                risk_score=9.0,
                mitigations=[
                    "Enable TLS/SSL encryption",
                    "Use VPN for internal communications",
                    "Implement end-to-end encryption for sensitive data"
                ]
            ))

        # A trust-level gap larger than 3 is treated as a significant boundary.
        if abs(source.trust_level - dest.trust_level) > 3:
            self.threats.append(Threat(
                id=self._generate_threat_id(),
                category=ThreatCategory.TAMPERING,
                component_id=flow.id,
                description=f"Data crosses significant trust boundary: {source.name} -> {dest.name}",
                risk_score=7.5,
                mitigations=[
                    "Implement strict input validation at boundary",
                    "Use message signing/verification",
                    "Add additional authentication at boundary"
                ]
            ))

    def _calculate_risk(self, component: Component, base_score: float) -> float:
        """Scale *base_score* by the component's data sensitivity and trust level.

        Sensitivity multiplies by 0.5 (public), 0.75 (internal),
        1.0 (confidential), or 1.25 (restricted). Unknown labels use 1.0.
        Trust multiplies by ``(11 - trust) / 10``: 1.1 at trust 0, down to
        0.1 at trust 10. Trust is clamped to 0-10 so out-of-range input
        cannot make the score negative.

        Returns:
            The score rounded to one decimal and kept within [0.0, 10.0].
        """
        score = base_score

        sensitivity_multiplier = {
            "public": 0.5,
            "internal": 0.75,
            "confidential": 1.0,
            "restricted": 1.25
        }
        score *= sensitivity_multiplier.get(component.data_sensitivity, 1.0)

        # Lower trust = higher risk.
        trust = min(max(component.trust_level, 0), 10)
        score *= (11 - trust) / 10

        return max(0.0, min(10.0, round(score, 1)))

    def generate_report(self) -> Dict:
        """Summarise the threats from the most recent ``analyze()`` call.

        Returns:
            A dict with ``summary`` counts, ``threats_by_category`` keyed by
            category name, ``high_risk_threats`` (score >= 8.0), and the top
            ``recommended_priorities``.
        """
        threats_by_category: Dict[str, List[Dict]] = {}
        for threat in self.threats:
            threats_by_category.setdefault(threat.category.value, []).append({
                "id": threat.id,
                "component": threat.component_id,
                "description": threat.description,
                "risk_score": threat.risk_score,
                "mitigations": threat.mitigations,
                "status": threat.status
            })

        high_risk = [t for t in self.threats if t.risk_score >= 8.0]
        medium_risk = [t for t in self.threats if 5.0 <= t.risk_score < 8.0]
        low_risk = [t for t in self.threats if t.risk_score < 5.0]

        return {
            "summary": {
                "total_threats": len(self.threats),
                "high_risk": len(high_risk),
                "medium_risk": len(medium_risk),
                "low_risk": len(low_risk),
                "components_analyzed": len(self.components),
                "data_flows_analyzed": len(self.data_flows)
            },
            "threats_by_category": threats_by_category,
            "high_risk_threats": [
                {"id": t.id, "description": t.description, "score": t.risk_score}
                for t in high_risk
            ],
            "recommended_priorities": self._prioritize_mitigations()
        }

    def _prioritize_mitigations(self, limit: Optional[int] = 10) -> List[Dict]:
        """Rank mitigations by the summed risk score of the threats they address.

        Args:
            limit: Maximum number of entries to return (default 10); ``None``
                returns all of them.
        """
        mitigation_impact: Dict[str, Dict] = {}

        for threat in self.threats:
            for mitigation in threat.mitigations:
                entry = mitigation_impact.setdefault(mitigation, {
                    "count": 0,
                    "total_risk_reduced": 0,
                    "threats": []
                })
                entry["count"] += 1
                entry["total_risk_reduced"] += threat.risk_score
                entry["threats"].append(threat.id)

        prioritized = [
            {
                "mitigation": m,
                "threats_addressed": data["count"],
                "risk_reduction": round(data["total_risk_reduced"], 1),
                "threat_ids": data["threats"]
            }
            for m, data in mitigation_impact.items()
        ]

        ranked = sorted(prioritized, key=lambda x: x["risk_reduction"], reverse=True)
        return ranked if limit is None else ranked[:limit]
 
 
# Example usage
def analyze_web_application():
    """Model a four-tier web application, print its STRIDE report as JSON, and return it.

    Components: an unauthenticated end user, the web application, a backend
    API, and a PostgreSQL database. Only the user -> web_app and
    web_app -> api flows are modelled; no flow reaches the database.
    """
    analyzer = STRIDEAnalyzer()

    model_components = [
        Component(
            id="user", name="End User", type="user",
            trust_level=2, data_sensitivity="public",
            authentication_required=False,
        ),
        Component(
            id="web_app", name="Web Application", type="web_app",
            trust_level=6, data_sensitivity="confidential",
            authentication_required=True,
        ),
        Component(
            id="api", name="Backend API", type="api",
            trust_level=7, data_sensitivity="confidential",
            authentication_required=True,
        ),
        Component(
            id="database", name="PostgreSQL Database", type="database",
            trust_level=9, data_sensitivity="restricted",
        ),
    ]
    for comp in model_components:
        analyzer.add_component(comp)

    model_flows = [
        DataFlow(
            id="flow_1", source="user", destination="web_app",
            protocol="HTTPS", data_type="user_credentials",
            encrypted=True, authenticated=False,
        ),
        DataFlow(
            id="flow_2", source="web_app", destination="api",
            protocol="HTTPS", data_type="api_request",
            encrypted=True, authenticated=True,
        ),
    ]
    for flow in model_flows:
        analyzer.add_data_flow(flow)

    # analyze() fills analyzer.threats, which generate_report() summarises.
    analyzer.analyze()
    summary_report = analyzer.generate_report()

    print(json.dumps(summary_report, indent=2))
    return summary_report

CI/CD Integration

Threat Model as Code

# threat-model.yaml
# Threat model as code: components, the data flows between them, trust
# boundaries and the security controls in place. CI analyses this file with
# scripts/stride_analyzer.py.
version: "1.0"
application: "ecommerce-platform"
team: "platform-engineering"
 
components:
  # trust_level: 0-10, higher = more trusted
  - id: web_frontend
    name: "Web Frontend"
    type: web_app
    trust_level: 5
    data_sensitivity: internal
    technologies:
      - React
      - TypeScript
 
  - id: api_gateway
    name: "API Gateway"
    type: api
    trust_level: 6
    data_sensitivity: confidential
    authentication:
      type: OAuth2
      mfa_enabled: true
 
  - id: user_service
    name: "User Service"
    type: api
    trust_level: 7
    data_sensitivity: restricted
 
  - id: payment_service
    name: "Payment Service"
    type: api
    trust_level: 8
    data_sensitivity: restricted
    pci_scope: true
 
  - id: postgres_db
    name: "Primary Database"
    type: database
    trust_level: 9
    data_sensitivity: restricted
    encryption_at_rest: true
 
data_flows:
  - from: web_frontend
    to: api_gateway
    protocol: HTTPS
    data_types:
      - user_input
      - session_token
    encrypted: true
 
  - from: api_gateway
    to: user_service
    protocol: gRPC
    data_types:
      - user_data
      - authentication
    encrypted: true
    mutual_tls: true
 
  - from: api_gateway
    to: payment_service
    protocol: gRPC
    data_types:
      - payment_info
      - pii
    encrypted: true
    mutual_tls: true
 
  # Service-to-database flows. Without them no data flow reaches
  # postgres_db, so the Data Layer never takes part in flow-level analysis.
  # NOTE(review): assumes both services persist to the primary database;
  # adjust to the real access paths.
  - from: user_service
    to: postgres_db
    protocol: PostgreSQL
    data_types:
      - user_data
    encrypted: true
 
  - from: payment_service
    to: postgres_db
    protocol: PostgreSQL
    data_types:
      - payment_info
      - pii
    encrypted: true
 
trust_boundaries:
  - name: "Internet Boundary"
    components:
      - web_frontend
    external: true
 
  - name: "Service Mesh"
    components:
      - api_gateway
      - user_service
      - payment_service
 
  - name: "Data Layer"
    components:
      - postgres_db
 
security_controls:
  authentication:
    - OAuth2 with PKCE
    - JWT validation
    - API key management
 
  authorization:
    - RBAC
    - Policy-based access control
 
  encryption:
    - TLS 1.3 minimum
    - AES-256 at rest
 
  monitoring:
    - Centralized logging
    - SIEM integration
    - Anomaly detection

GitHub Actions Threat Analysis

# .github/workflows/threat-model.yml
name: Threat Model Analysis
 
on:
  pull_request:
    paths:
      - 'threat-model.yaml'
      - 'src/**'
      - 'infrastructure/**'
  push:
    branches: [main]
  schedule:
    - cron: '0 6 * * 1'  # Weekly Monday 6 AM (UTC)
 
# Least privilege by default; the analysis job widens this only to comment on PRs.
permissions:
  contents: read
 
jobs:
  analyze-threats:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
    env:
      # Maximum number of high-risk threats tolerated before the job fails.
      MAX_HIGH_RISK: 5
 
    steps:
      - uses: actions/checkout@v4
 
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
 
      - name: Install dependencies
        run: |
          pip install pyyaml jsonschema
 
      - name: Validate threat model schema
        run: |
          python scripts/validate_threat_model.py threat-model.yaml
 
      - name: Run STRIDE analysis
        id: stride
        run: |
          python scripts/stride_analyzer.py threat-model.yaml > threat-report.json
 
          # Extract summary for PR comment
          HIGH_RISK=$(jq '.summary.high_risk' threat-report.json)
          MEDIUM_RISK=$(jq '.summary.medium_risk' threat-report.json)
          TOTAL=$(jq '.summary.total_threats' threat-report.json)
 
          echo "high_risk=$HIGH_RISK" >> "$GITHUB_OUTPUT"
          echo "medium_risk=$MEDIUM_RISK" >> "$GITHUB_OUTPUT"
          echo "total=$TOTAL" >> "$GITHUB_OUTPUT"
 
      # The report upload and PR comment run BEFORE the threshold gate: a
      # failing step skips every later step, and these outputs matter most
      # exactly when the gate fails.
      - name: Upload threat report
        uses: actions/upload-artifact@v4
        with:
          name: threat-report
          path: threat-report.json
 
      - name: Comment on PR
        if: github.event_name == 'pull_request'
        # Fork PRs get a read-only token; a failed comment must not mask the gate.
        continue-on-error: true
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            const report = JSON.parse(fs.readFileSync('threat-report.json', 'utf8'));
 
            const body = `## 🔒 Threat Model Analysis
 
            | Metric | Count |
            |--------|-------|
            | Total Threats | ${report.summary.total_threats} |
            | 🔴 High Risk | ${report.summary.high_risk} |
            | 🟡 Medium Risk | ${report.summary.medium_risk} |
            | 🟢 Low Risk | ${report.summary.low_risk} |
 
            ### Top Recommended Mitigations
            ${report.recommended_priorities.slice(0, 5).map((m, i) =>
              `${i + 1}. **${m.mitigation}** - Addresses ${m.threats_addressed} threats`
            ).join('\n')}
 
            <details>
            <summary>High Risk Threats</summary>
 
            ${report.high_risk_threats.map(t =>
              `- **${t.id}** (Score: ${t.score}): ${t.description}`
            ).join('\n')}
            </details>
            `;
 
            // Await so API errors surface in this step instead of being dropped.
            await github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: body
            });
 
      - name: Check threat thresholds
        env:
          HIGH_RISK: ${{ steps.stride.outputs.high_risk }}
        run: |
          if [ "$HIGH_RISK" -gt "$MAX_HIGH_RISK" ]; then
            echo "::error::Too many high-risk threats identified ($HIGH_RISK)"
            exit 1
          fi
 
  compare-baseline:
    runs-on: ubuntu-latest
    needs: analyze-threats
    if: github.event_name == 'pull_request'
 
    steps:
      # PR head at the workspace root: provides compare_threats.py and
      # receives the downloaded current report.
      - uses: actions/checkout@v4
 
      # Base commit in its own directory. actions/checkout cleans its target
      # path by default (git clean -ffdx), so checking out again over the
      # workspace would delete any untracked baseline report written there.
      - name: Check out PR base
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.base.sha }}
          path: baseline
 
      # The analyzer needs its dependencies here too; otherwise it fails and
      # the fallback below silently reports a zero-threat baseline.
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
 
      - name: Install dependencies
        run: |
          pip install pyyaml jsonschema
 
      - name: Get baseline threat report
        working-directory: baseline
        run: |
          python scripts/stride_analyzer.py threat-model.yaml > ../baseline-report.json || echo '{"summary":{"total_threats":0}}' > ../baseline-report.json
 
      - name: Download current report
        uses: actions/download-artifact@v4
        with:
          name: threat-report
 
      - name: Compare reports
        run: |
          python scripts/compare_threats.py baseline-report.json threat-report.json

Comparison Script

# scripts/compare_threats.py
import json
import sys
 
def compare_reports(baseline_path: str, current_path: str):
    """Compare two threat reports and exit with status 1 if high-risk threats grew.

    Prints the total and high-risk deltas, then lists threats present in the
    current report but absent from the baseline. Threats are matched by
    (category, component, description) rather than by ``id``. IDs are
    assigned sequentially during analysis, so adding or removing a single
    component renumbers later threats. An ID comparison would then report
    unrelated threats as "new".

    Args:
        baseline_path: JSON report for the base revision. May be a minimal
            ``{"summary": {...}}`` document without ``threats_by_category``.
        current_path: JSON report for the change under review.

    Raises:
        SystemExit: with code 1 when the high-risk count increased.
    """
    with open(baseline_path, encoding="utf-8") as f:
        baseline = json.load(f)

    with open(current_path, encoding="utf-8") as f:
        current = json.load(f)

    baseline_summary = baseline.get('summary', {})
    current_summary = current.get('summary', {})

    baseline_total = baseline_summary.get('total_threats', 0)
    current_total = current_summary.get('total_threats', 0)

    baseline_high = baseline_summary.get('high_risk', 0)
    current_high = current_summary.get('high_risk', 0)

    diff_total = current_total - baseline_total
    diff_high = current_high - baseline_high

    print("Threat Comparison:")
    print(f"  Total: {baseline_total} -> {current_total} ({'+' if diff_total > 0 else ''}{diff_total})")
    print(f"  High Risk: {baseline_high} -> {current_high} ({'+' if diff_high > 0 else ''}{diff_high})")

    def _signature(category, threat):
        # Stable identity of a finding, independent of its sequential id.
        return (category, threat.get('component'), threat.get('description'))

    baseline_signatures = {
        _signature(category, threat)
        for category, threats in baseline.get('threats_by_category', {}).items()
        for threat in threats
    }

    new_threats = [
        threat
        for category, threats in current.get('threats_by_category', {}).items()
        for threat in threats
        if _signature(category, threat) not in baseline_signatures
    ]

    if new_threats:
        print(f"\nNew Threats Identified ({len(new_threats)}):")
        for threat in new_threats:
            description = threat.get('description', '')
            # Only mark truncation when something was actually cut off.
            if len(description) > 60:
                description = description[:60] + "..."
            print(f"  - {threat.get('id', '?')}: {description}")

    # Fail if high-risk threats increased
    if diff_high > 0:
        print(f"\n❌ High-risk threats increased by {diff_high}")
        sys.exit(1)

    print("\n✅ No increase in high-risk threats")
 
if __name__ == "__main__":
    if len(sys.argv) != 3:
        # Usage errors go to stderr so they never mix with comparison output.
        print("Usage: compare_threats.py <baseline.json> <current.json>", file=sys.stderr)
        sys.exit(1)
 
    compare_reports(sys.argv[1], sys.argv[2])

Attack Tree Generation

from dataclasses import dataclass, field
from typing import List, Dict, Optional
import json
 
@dataclass
class AttackNode:
    """One node of an attack tree.

    ``children`` holds nested AttackNode objects. ``operator`` ("AND" or "OR",
    default "OR") describes how those children combine. Presumably AND means
    every child is required and OR means any one suffices.
    """

    id: str
    description: str
    node_type: str  # goal, sub_goal, attack, defense
    probability: float = field(default=0.0)
    cost: float = field(default=0.0)
    children: List['AttackNode'] = field(default_factory=list)
    operator: str = field(default="OR")  # AND, OR
 
class AttackTreeGenerator:
    """Build attack trees (goal -> sub-goals -> attacks -> defenses) from a STRIDE report.

    The report is stored as ``self.threats``. Only its
    ``threats_by_category`` mapping is read: category name -> list of threat
    dicts with ``description``, ``risk_score`` and ``mitigations``.
    """

    # (trigger substrings, STRIDE categories) used to select threats for a goal.
    # Single words rather than fixed phrases, so goals such as "Gain
    # unauthorized administrative access" or "Disrupt service availability"
    # are recognised. Every phrase previously accepted ("data breach",
    # "account takeover", "service disruption", "unauthorized access",
    # "data manipulation") still contains one of these triggers.
    _GOAL_KEYWORDS = (
        (("breach", "leak", "exfiltrat"), ("Information Disclosure", "Tampering")),
        (("takeover", "impersonat"), ("Spoofing", "Elevation of Privilege")),
        (("disrupt", "availability", "outage"), ("Denial of Service",)),
        (("unauthorized", "privilege"), ("Spoofing", "Elevation of Privilege")),
        (("manipulat", "tamper"), ("Tampering", "Repudiation")),
    )

    def __init__(self, threat_report: Dict):
        self.threats = threat_report
        self.node_counter = 0

    def _generate_node_id(self) -> str:
        """Return the next node id (ATK-0001, ...); unique across trees from this generator."""
        self.node_counter += 1
        return f"ATK-{self.node_counter:04d}"

    def generate_tree(self, goal: str) -> "AttackNode":
        """Generate an attack tree for *goal*.

        The root (OR) has one OR sub-goal per relevant STRIDE category. Each
        sub-goal has one attack per threat, with probability = risk_score / 10.
        Each attack carries up to three defense children taken from the
        threat's mitigations. Empty sub-goals are omitted.
        """
        root = AttackNode(
            id=self._generate_node_id(),
            description=goal,
            node_type="goal",
            operator="OR"
        )

        for category, threats in self._find_related_threats(goal).items():
            sub_goal = AttackNode(
                id=self._generate_node_id(),
                description=f"Exploit via {category}",
                node_type="sub_goal",
                operator="OR"
            )

            for threat in threats:
                attack = AttackNode(
                    id=self._generate_node_id(),
                    description=threat['description'],
                    node_type="attack",
                    probability=threat['risk_score'] / 10,
                    cost=self._estimate_attack_cost(threat)
                )

                # Attach the top mitigations as defense nodes.
                for mitigation in threat['mitigations'][:3]:
                    attack.children.append(AttackNode(
                        id=self._generate_node_id(),
                        description=mitigation,
                        node_type="defense"
                    ))

                sub_goal.children.append(attack)

            if sub_goal.children:
                root.children.append(sub_goal)

        return root

    def _find_related_threats(self, goal: str) -> Dict[str, List]:
        """Return the report's threats grouped by the categories relevant to *goal*.

        A category is relevant when any of its trigger words occurs in the
        lower-cased goal. If nothing matches, every category in the report is
        returned. Categories absent from the report are skipped. Order follows
        ``_GOAL_KEYWORDS`` without duplicates.
        """
        goal_lower = goal.lower()
        by_category = self.threats.get('threats_by_category', {})

        target_categories: List[str] = []
        for triggers, categories in self._GOAL_KEYWORDS:
            if any(word in goal_lower for word in triggers):
                target_categories.extend(categories)

        if not target_categories:
            target_categories = list(by_category)

        # dict.fromkeys de-duplicates while keeping first-seen order.
        return {
            category: by_category[category]
            for category in dict.fromkeys(target_categories)
            if category in by_category
        }

    def _estimate_attack_cost(self, threat: Dict) -> float:
        """Estimate attacker effort: 5.0 plus 0.5 per point the risk score is below 10."""
        base_cost = 5.0

        # Higher risk often means easier to exploit
        risk_adjustment = (10 - threat['risk_score']) * 0.5

        return base_cost + risk_adjustment

    @staticmethod
    def _mermaid_label(text: str, max_len: int = 40) -> str:
        """Truncate *text* to *max_len* chars (appending "...") and escape quotes for Mermaid."""
        if len(text) > max_len:
            text = f"{text[:max_len]}..."
        # A raw double quote would terminate the ["..."] label early.
        return text.replace('"', "#quot;")

    def to_mermaid(self, root: "AttackNode") -> str:
        """Render the tree rooted at *root* as a Mermaid ``graph TD`` diagram.

        An edge is labelled ``|AND|`` when the PARENT node's operator is "AND",
        because the operator describes how a node's children combine.
        """
        lines = ["graph TD"]

        # Style based on node type
        styles = {
            "goal": ":::goal",
            "sub_goal": ":::subgoal",
            "attack": ":::attack",
            "defense": ":::defense"
        }

        def process_node(node: "AttackNode", parent_id: Optional[str] = None,
                         parent_operator: str = "OR"):
            label = self._mermaid_label(node.description)
            lines.append(f"    {node.id}[\"{label}\"]{styles.get(node.node_type, '')}")

            if parent_id is not None:
                connector = "-->|AND|" if parent_operator == "AND" else "-->"
                lines.append(f"    {parent_id} {connector} {node.id}")

            for child in node.children:
                process_node(child, node.id, node.operator)

        process_node(root)

        # Add styles
        lines.extend([
            "",
            "    classDef goal fill:#ff6b6b,stroke:#333,stroke-width:2px",
            "    classDef subgoal fill:#ffd93d,stroke:#333",
            "    classDef attack fill:#ff8c42,stroke:#333",
            "    classDef defense fill:#6bcb77,stroke:#333"
        ])

        return "\n".join(lines)
 
 
# Generate attack trees from threat report
def generate_attack_trees(threat_report: Dict):
    """Build attack trees for three standard attacker goals.

    A single generator is shared, so node ids keep increasing across the three
    trees in goal order.

    Returns:
        A dict mapping each goal string to ``{"tree": root_node, "mermaid": diagram}``.
    """
    tree_builder = AttackTreeGenerator(threat_report)

    goals = (
        "Achieve data breach of customer information",
        "Gain unauthorized administrative access",
        "Disrupt service availability",
    )

    results = {}
    for goal_text in goals:
        root_node = tree_builder.generate_tree(goal_text)
        diagram = tree_builder.to_mermaid(root_node)
        results[goal_text] = {"tree": root_node, "mermaid": diagram}

    return results

Continuous Threat Validation

import subprocess
import json
from typing import List, Dict
 
class ThreatValidator:
    """Heuristically check that common threat-model mitigations exist in a repo.

    Each check runs a recursive, case-insensitive ``grep`` for a keyword under
    a fixed relative path (``src/``, ``src/auth/``, ``infrastructure/``, ...).
    A match is evidence, not proof, that the mitigation is implemented.
    """

    def __init__(self, threat_model_path: str, repo_root: Optional[str] = None):
        """Load the threat model and record where the checks should run.

        Args:
            threat_model_path: Path to the threat model file. Only paths ending
                in ``.json`` are parsed. Any other path is still opened, so a
                missing file raises ``FileNotFoundError``, but the model is ``{}``.
            repo_root: Directory that the checks' relative search paths are
                resolved against. ``None`` means the process's current working
                directory.
        """
        with open(threat_model_path, encoding="utf-8") as f:
            self.model = json.load(f) if threat_model_path.endswith('.json') else {}
        # NOTE: the checks in this class do not read self.model.
        self.repo_root = repo_root

    def validate_mitigations(self) -> List[Dict]:
        """Validate that mitigations are implemented.

        Returns:
            One dict per mitigation, with keys ``mitigation``, ``implemented``
            (bool) and ``details`` (str). A check that raises unexpectedly is
            recorded as not implemented rather than aborting the whole run.
        """
        results = []

        validation_checks = {
            "Implement multi-factor authentication": self._check_mfa,
            "Use HTTPS/TLS": self._check_tls,
            "Implement rate limiting": self._check_rate_limiting,
            "Enable audit logging": self._check_audit_logging,
            "Encrypt data at rest": self._check_encryption_at_rest
        }

        for mitigation, check_func in validation_checks.items():
            try:
                is_implemented, details = check_func()
                results.append({
                    "mitigation": mitigation,
                    "implemented": is_implemented,
                    "details": details
                })
            except Exception as e:
                results.append({
                    "mitigation": mitigation,
                    "implemented": False,
                    "details": f"Check failed: {str(e)}"
                })

        return results

    def _grep_check(self, pattern: str, path: str, found_msg: str, missing_msg: str,
                    error_msg: str, include: Optional[str] = None) -> tuple:
        """Run a recursive, case-insensitive grep and map it to ``(found, message)``.

        grep's exit status separates the three outcomes:
        0 means at least one match, 1 means a clean search with no match, and
        2 or higher means an error (e.g. *path* does not exist). An error, a
        missing ``grep`` binary, or an unusable ``repo_root`` all yield
        *error_msg*, so "could not check" is never reported as "absent".

        Args:
            pattern: Keyword to search for, matched case-insensitively.
            path: File or directory to search, relative to ``repo_root``.
            found_msg: Message returned when a match is found.
            missing_msg: Message returned when the search succeeds with no match.
            error_msg: Message returned when the search could not be performed.
            include: Optional ``--include`` glob restricting which files are read.
        """
        # -q stops at the first match; only the exit status is inspected.
        cmd = ["grep", "-r", "-q", "-i"]
        if include is not None:
            cmd.append(f"--include={include}")
        cmd += ["-e", pattern, path]
        try:
            result = subprocess.run(cmd, cwd=self.repo_root, capture_output=True, text=True)
        except (OSError, subprocess.SubprocessError):
            return False, error_msg
        if result.returncode == 0:
            return True, found_msg
        if result.returncode == 1:
            return False, missing_msg
        return False, error_msg

    def _check_mfa(self) -> tuple:
        """Check if MFA is configured: keyword "mfa" in ``*.ts`` files under ``src/auth/``."""
        return self._grep_check(
            "mfa", "src/auth/",
            "MFA configuration found",
            "No MFA configuration detected",
            "Could not verify MFA configuration",
            include="*.ts",
        )

    def _check_tls(self) -> tuple:
        """Check TLS configuration: keyword "ssl_certificate" under ``infrastructure/``."""
        return self._grep_check(
            "ssl_certificate", "infrastructure/",
            "TLS certificates configured",
            "No TLS configuration found",
            "Could not verify TLS configuration",
        )

    def _check_rate_limiting(self) -> tuple:
        """Check rate limiting: keyword "rateLimit" in ``*.ts`` files under ``src/``."""
        return self._grep_check(
            "rateLimit", "src/",
            "Rate limiting implemented",
            "No rate limiting found",
            "Could not verify rate limiting",
            include="*.ts",
        )

    def _check_audit_logging(self) -> tuple:
        """Check audit logging: keyword "audit" in ``*.ts`` files under ``src/``."""
        return self._grep_check(
            "audit", "src/",
            "Audit logging found",
            "No audit logging detected",
            "Could not verify audit logging",
            include="*.ts",
        )

    def _check_encryption_at_rest(self) -> tuple:
        """Check encryption at rest: keyword "encrypt" under ``infrastructure/database/``."""
        return self._grep_check(
            "encrypt", "infrastructure/database/",
            "Encryption configured",
            "No encryption config found",
            "Could not verify encryption",
        )

    def generate_validation_report(self) -> Dict:
        """Generate full validation report.

        Returns:
            A dict with a ``summary`` (counts plus a ``coverage`` percentage
            string), the implemented and missing mitigation records, and one
            "Implement: ..." recommendation per missing mitigation.
        """
        validations = self.validate_mitigations()

        implemented = [v for v in validations if v['implemented']]
        not_implemented = [v for v in validations if not v['implemented']]

        total = len(validations)
        # Guard against an empty check list instead of dividing by zero.
        coverage = len(implemented) / total * 100 if total else 0.0

        return {
            "summary": {
                "total_checks": total,
                "implemented": len(implemented),
                "not_implemented": len(not_implemented),
                "coverage": f"{coverage:.1f}%"
            },
            "implemented_mitigations": implemented,
            "missing_mitigations": not_implemented,
            "recommendations": [
                f"Implement: {v['mitigation']}" for v in not_implemented
            ]
        }

Best Practices

Threat Modeling Integration

  1. Start early: Begin threat modeling during design phase
  2. Iterate continuously: Update models as architecture evolves
  3. Automate validation: Check mitigations in CI/CD
  4. Track metrics: Monitor threat counts over time
  5. Involve stakeholders: Include developers, security, and operations

STRIDE Coverage

  • Ensure all components are analyzed for each STRIDE category
  • Focus on trust boundaries and data flows
  • Document assumptions and accepted risks
  • Review mitigations for effectiveness

Integrating threat modeling into DevSecOps pipelines ensures security is considered throughout the development lifecycle, not just at deployment.

Weekly AI Security & Automation Digest

Get the latest on AI Security, workflow automation, secure integrations, and custom platform development delivered weekly.

No spam. Unsubscribe anytime.