DevSecOps

Cloud Security Posture Management: Automating Cloud Configuration Compliance

DeviDevs Team
10 min read
#CSPM #cloud-security #compliance #AWS #Azure #GCP

Cloud Security Posture Management: Automating Cloud Configuration Compliance

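Cloud Security Posture Management (CSPM) helps organizations identify and remediate misconfigurations across cloud environments. This guide builds an automated CSPM pipeline in Python. The pipeline consists of a provider-agnostic scanner interface with a concrete AWS implementation, plus a reporter that summarizes findings by risk, resource, provider, and compliance framework.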

CSPM Framework Architecture

Multi-Cloud Security Scanner

from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Any
from enum import Enum
from abc import ABC, abstractmethod
import json
 
class CloudProvider(Enum):
    """Cloud platforms that scanners, policies and findings refer to.

    Each member's value is the lowercase slug emitted in report output.
    """

    AWS = "aws"
    AZURE = "azure"
    GCP = "gcp"
 
class RiskLevel(Enum):
    """Severity assigned to a policy and to the findings it produces.

    Members are declared from most to least severe; their values are the
    lowercase labels used as report keys.
    """

    CRITICAL = "critical"
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
    INFO = "info"
 
class ComplianceFramework(Enum):
    """Compliance standards a security policy can be mapped to.

    Values are lowercase identifiers used when filtering and reporting.
    """

    CIS = "cis"
    SOC2 = "soc2"
    PCI_DSS = "pci_dss"
    HIPAA = "hipaa"
    NIST = "nist"
    GDPR = "gdpr"
 
@dataclass
class Finding:
    """One policy violation detected on one cloud resource.

    Scanners create these when a resource fails a ``SecurityPolicy`` check;
    the reporter aggregates and serialises them.
    """

    # Identity and human-readable summary of the violation.
    id: str
    title: str
    description: str
    # Where the offending resource lives.
    resource_id: str
    resource_type: str
    region: str
    provider: CloudProvider
    # Severity and the frameworks whose controls are affected.
    risk_level: RiskLevel
    compliance_frameworks: List[ComplianceFramework]
    # How to fix it, plus the raw data that led to the verdict.
    remediation: str
    evidence: Dict[str, Any]
    # Naive UTC timestamp taken when the Finding object is constructed.
    discovered_at: datetime = field(default_factory=datetime.utcnow)
 
@dataclass
class SecurityPolicy:
    """A declarative rule evaluated against resources of one type.

    ``check_function`` names the check routine a scanner looks up to
    decide whether a resource complies; ``enabled`` lets a policy be kept
    in the catalogue but skipped during scans.
    """

    # Identity and description of the rule.
    id: str
    name: str
    description: str
    # Which resources the rule applies to.
    provider: CloudProvider
    resource_type: str
    # Key of the scanner-side check routine implementing the rule.
    check_function: str
    # Severity and framework mapping copied onto any resulting finding.
    risk_level: RiskLevel
    compliance_frameworks: List[ComplianceFramework]
    # Ordered fix instructions.
    remediation_steps: List[str]
    enabled: bool = True
 
class CloudSecurityScanner(ABC):
    """Interface implemented by every provider-specific security scanner."""

    @abstractmethod
    def scan(self, policies: List[SecurityPolicy]) -> List[Finding]:
        """Evaluate ``policies`` against the provider's resources.

        Returns the findings for resources that fail a policy check.
        """

    @abstractmethod
    def get_resources(self, resource_type: str) -> List[Dict]:
        """Return the provider's resources of ``resource_type`` as dicts."""
 
 
class AWSSecurityScanner(CloudSecurityScanner):
    """AWS-specific security scanner.

    Resource inventories are fetched with boto3 clients created from
    ``self.session`` and normalised into plain dictionaries. Each enabled AWS
    ``SecurityPolicy`` is then evaluated against every resource of its
    ``resource_type`` using the check named by ``policy.check_function``.

    Per-resource configuration lookups are best-effort: when an API call
    fails, the value is recorded as absent instead of aborting the scan.
    """

    # Ports that should never be reachable from the whole internet:
    # SSH, RDP, MySQL, PostgreSQL, MS SQL Server, MongoDB.
    DANGEROUS_PORTS = (22, 3389, 3306, 5432, 1433, 27017)

    # CIDR blocks meaning "any address" for IPv4 and IPv6 rules.
    WORLD_CIDRS_V4 = ('0.0.0.0/0',)
    WORLD_CIDRS_V6 = ('::/0',)

    # Protocols whose FromPort/ToPort fields hold ICMP type/code, not ports.
    ICMP_PROTOCOLS = ('icmp', 'icmpv6', '1', '58')

    # Active IAM access keys older than this many days are reported.
    MAX_ACCESS_KEY_AGE_DAYS = 90

    def __init__(self, session=None):
        """Create a scanner.

        Args:
            session: boto3 ``Session`` used to create clients. A default
                session is created when omitted.
        """
        self.session = session or self._create_session()
        self.findings: List[Finding] = []

    def _create_session(self):
        """Create a default boto3 session (boto3 is imported lazily)."""
        import boto3
        return boto3.Session()

    def scan(self, policies: List[SecurityPolicy]) -> List[Finding]:
        """Scan AWS resources against security policies.

        Policies for other providers and disabled policies are skipped. The
        findings are also stored on ``self.findings``, replacing any results
        from a previous scan.

        Args:
            policies: Policies to evaluate.

        Returns:
            One ``Finding`` per (policy, resource) pair that failed its check.
        """
        self.findings = []
        # Fetch each resource type at most once per scan: several policies
        # usually target the same type (e.g. four S3 bucket checks), and
        # some fetchers issue several API calls per resource.
        inventory: Dict[str, List[Dict]] = {}

        for policy in policies:
            if policy.provider != CloudProvider.AWS or not policy.enabled:
                continue

            if policy.resource_type not in inventory:
                inventory[policy.resource_type] = self.get_resources(policy.resource_type)

            for resource in inventory[policy.resource_type]:
                finding = self._evaluate_policy(policy, resource)
                if finding:
                    self.findings.append(finding)

        return self.findings

    def get_resources(self, resource_type: str) -> List[Dict]:
        """Get AWS resources by type.

        Args:
            resource_type: One of the keys of ``resource_fetchers`` below.

        Returns:
            Normalised resource dictionaries, or ``[]`` for unknown types.
        """
        resource_fetchers = {
            "s3_bucket": self._get_s3_buckets,
            "ec2_instance": self._get_ec2_instances,
            "security_group": self._get_security_groups,
            "iam_user": self._get_iam_users,
            "rds_instance": self._get_rds_instances,
            "lambda_function": self._get_lambda_functions,
            "cloudtrail": self._get_cloudtrails,
            "kms_key": self._get_kms_keys
        }

        fetcher = resource_fetchers.get(resource_type)
        if fetcher:
            return fetcher()
        return []

    def _get_s3_buckets(self) -> List[Dict]:
        """Get all S3 buckets with their security-relevant configuration.

        Each configuration value is ``None`` when its lookup fails. S3
        reports an error both when a configuration is absent and when
        access is denied.
        """
        s3 = self.session.client('s3')
        buckets = []

        response = s3.list_buckets()
        for bucket in response.get('Buckets', []):
            bucket_name = bucket['Name']
            bucket_info = {
                'name': bucket_name,
                'creation_date': bucket['CreationDate'].isoformat(),
                # Reuse this client instead of creating one per bucket.
                'region': self._get_bucket_region(bucket_name, s3)
            }

            try:
                # Encryption
                encryption = s3.get_bucket_encryption(Bucket=bucket_name)
                bucket_info['encryption'] = encryption.get('ServerSideEncryptionConfiguration')
            except s3.exceptions.ClientError:
                bucket_info['encryption'] = None

            try:
                # Versioning
                versioning = s3.get_bucket_versioning(Bucket=bucket_name)
                bucket_info['versioning'] = versioning.get('Status')
            except Exception:
                bucket_info['versioning'] = None

            try:
                # Public access block
                public_access = s3.get_public_access_block(Bucket=bucket_name)
                bucket_info['public_access_block'] = public_access.get('PublicAccessBlockConfiguration')
            except Exception:
                bucket_info['public_access_block'] = None

            try:
                # Logging
                logging_config = s3.get_bucket_logging(Bucket=bucket_name)
                bucket_info['logging'] = logging_config.get('LoggingEnabled')
            except Exception:
                bucket_info['logging'] = None

            buckets.append(bucket_info)

        return buckets

    def _get_bucket_region(self, bucket_name: str, s3=None) -> str:
        """Get S3 bucket region.

        ``GetBucketLocation`` returns no constraint for us-east-1 and the
        legacy value ``'EU'`` for eu-west-1; both are normalised to real
        region names.

        Args:
            bucket_name: Bucket to look up.
            s3: Optional existing S3 client; a new one is created if omitted.

        Returns:
            The region name, or ``'unknown'`` when the lookup fails.
        """
        if s3 is None:
            s3 = self.session.client('s3')
        try:
            response = s3.get_bucket_location(Bucket=bucket_name)
        except Exception:
            return 'unknown'
        location = response.get('LocationConstraint') or 'us-east-1'
        return 'eu-west-1' if location == 'EU' else location

    def _get_ec2_instances(self) -> List[Dict]:
        """Get all EC2 instances visible to the session's EC2 client."""
        ec2 = self.session.client('ec2')
        instances = []

        paginator = ec2.get_paginator('describe_instances')
        for page in paginator.paginate():
            for reservation in page['Reservations']:
                for instance in reservation['Instances']:
                    instances.append({
                        'id': instance['InstanceId'],
                        'type': instance['InstanceType'],
                        'state': instance['State']['Name'],
                        'vpc_id': instance.get('VpcId'),
                        'subnet_id': instance.get('SubnetId'),
                        'security_groups': instance.get('SecurityGroups', []),
                        'iam_profile': instance.get('IamInstanceProfile'),
                        'public_ip': instance.get('PublicIpAddress'),
                        'metadata_options': instance.get('MetadataOptions', {}),
                        'ebs_optimized': instance.get('EbsOptimized', False),
                        'monitoring': instance.get('Monitoring', {}).get('State')
                    })

        return instances

    def _get_security_groups(self) -> List[Dict]:
        """Get all security groups.

        Uses the paginator: a single ``describe_security_groups`` call
        returns only the first page of results.
        """
        ec2 = self.session.client('ec2')
        groups = []

        paginator = ec2.get_paginator('describe_security_groups')
        for page in paginator.paginate():
            for sg in page['SecurityGroups']:
                groups.append({
                    'id': sg['GroupId'],
                    'name': sg['GroupName'],
                    'description': sg['Description'],
                    'vpc_id': sg.get('VpcId'),
                    'inbound_rules': sg.get('IpPermissions', []),
                    'outbound_rules': sg.get('IpPermissionsEgress', [])
                })

        return groups

    def _get_iam_users(self) -> List[Dict]:
        """Get all IAM users with their MFA and access-key state.

        If listing MFA devices fails, the user is recorded as
        ``mfa_enabled=False``. If listing access keys fails, the user is
        recorded with no access keys.
        """
        iam = self.session.client('iam')
        users = []

        paginator = iam.get_paginator('list_users')
        for page in paginator.paginate():
            for user in page['Users']:
                password_last_used = user.get('PasswordLastUsed')
                user_info = {
                    'name': user['UserName'],
                    'arn': user['Arn'],
                    'created': user['CreateDate'].isoformat(),
                    'password_last_used': password_last_used.isoformat() if password_last_used else None
                }

                # Get MFA devices
                try:
                    mfa_response = iam.list_mfa_devices(UserName=user['UserName'])
                    user_info['mfa_enabled'] = len(mfa_response['MFADevices']) > 0
                except Exception:
                    user_info['mfa_enabled'] = False

                # Get access keys
                try:
                    keys_response = iam.list_access_keys(UserName=user['UserName'])
                    user_info['access_keys'] = [
                        {
                            'id': key['AccessKeyId'],
                            'status': key['Status'],
                            'created': key['CreateDate'].isoformat()
                        }
                        for key in keys_response['AccessKeyMetadata']
                    ]
                except Exception:
                    user_info['access_keys'] = []

                users.append(user_info)

        return users

    def _get_rds_instances(self) -> List[Dict]:
        """Get all RDS instances."""
        rds = self.session.client('rds')
        instances = []

        paginator = rds.get_paginator('describe_db_instances')
        for page in paginator.paginate():
            for db in page['DBInstances']:
                instances.append({
                    'id': db['DBInstanceIdentifier'],
                    'engine': db['Engine'],
                    'status': db['DBInstanceStatus'],
                    'publicly_accessible': db.get('PubliclyAccessible', False),
                    'storage_encrypted': db.get('StorageEncrypted', False),
                    'multi_az': db.get('MultiAZ', False),
                    'backup_retention': db.get('BackupRetentionPeriod', 0),
                    'deletion_protection': db.get('DeletionProtection', False),
                    'auto_minor_upgrade': db.get('AutoMinorVersionUpgrade', False)
                })

        return instances

    def _get_lambda_functions(self) -> List[Dict]:
        """Get all Lambda functions."""
        lambda_client = self.session.client('lambda')
        functions = []

        paginator = lambda_client.get_paginator('list_functions')
        for page in paginator.paginate():
            for func in page['Functions']:
                functions.append({
                    'name': func['FunctionName'],
                    'runtime': func.get('Runtime'),
                    'role': func['Role'],
                    'vpc_config': func.get('VpcConfig', {}),
                    'environment': func.get('Environment', {}).get('Variables', {}),
                    'kms_key_arn': func.get('KMSKeyArn'),
                    'tracing': func.get('TracingConfig', {}).get('Mode'),
                    'timeout': func.get('Timeout'),
                    'memory_size': func.get('MemorySize')
                })

        return functions

    def _get_cloudtrails(self) -> List[Dict]:
        """Get all CloudTrail configurations and their logging status.

        ``is_logging`` is ``False`` when the status lookup fails.
        """
        cloudtrail = self.session.client('cloudtrail')
        trails = []

        response = cloudtrail.describe_trails()
        for trail in response['trailList']:
            trail_info = {
                'name': trail['Name'],
                'arn': trail['TrailARN'],
                's3_bucket': trail.get('S3BucketName'),
                'is_multi_region': trail.get('IsMultiRegionTrail', False),
                'log_file_validation': trail.get('LogFileValidationEnabled', False),
                'kms_key_id': trail.get('KmsKeyId'),
                'is_organization_trail': trail.get('IsOrganizationTrail', False)
            }

            try:
                # Query by ARN: GetTrailStatus only resolves shadow trails
                # (multi-region trails homed in another region) by ARN, so a
                # lookup by name would fail and report the trail as not logging.
                status = cloudtrail.get_trail_status(Name=trail['TrailARN'])
                trail_info['is_logging'] = status.get('IsLogging', False)
            except Exception:
                trail_info['is_logging'] = False

            trails.append(trail_info)

        return trails

    def _get_kms_keys(self) -> List[Dict]:
        """Get all KMS keys with their rotation status.

        A key is skipped entirely when its metadata or rotation status
        cannot be read.
        """
        kms = self.session.client('kms')
        keys = []

        paginator = kms.get_paginator('list_keys')
        for page in paginator.paginate():
            for key in page['Keys']:
                try:
                    key_info = kms.describe_key(KeyId=key['KeyId'])['KeyMetadata']
                    rotation = kms.get_key_rotation_status(KeyId=key['KeyId'])
                except Exception:
                    continue

                keys.append({
                    'id': key['KeyId'],
                    'arn': key_info['Arn'],
                    'description': key_info.get('Description', ''),
                    'key_state': key_info['KeyState'],
                    'key_usage': key_info['KeyUsage'],
                    'origin': key_info['Origin'],
                    'rotation_enabled': rotation.get('KeyRotationEnabled', False)
                })

        return keys

    def _evaluate_policy(self, policy: SecurityPolicy, resource: Dict) -> Optional[Finding]:
        """Evaluate a security policy against a resource.

        Returns:
            A ``Finding`` when the resource fails the check. Returns ``None``
            when it complies or when ``policy.check_function`` is unknown.
        """
        checks = {
            "s3_encryption_enabled": self._check_s3_encryption,
            "s3_versioning_enabled": self._check_s3_versioning,
            "s3_public_access_blocked": self._check_s3_public_access,
            "s3_logging_enabled": self._check_s3_logging,
            "ec2_imdsv2_required": self._check_ec2_imdsv2,
            "ec2_public_ip": self._check_ec2_public_ip,
            "sg_open_to_world": self._check_sg_open_ports,
            "iam_mfa_enabled": self._check_iam_mfa,
            "iam_access_key_rotation": self._check_iam_key_rotation,
            "rds_public_access": self._check_rds_public,
            "rds_encryption": self._check_rds_encryption,
            "rds_backup_enabled": self._check_rds_backup,
            "cloudtrail_enabled": self._check_cloudtrail_enabled,
            "kms_rotation_enabled": self._check_kms_rotation
        }

        check_func = checks.get(policy.check_function)
        if not check_func:
            return None

        is_compliant, evidence = check_func(resource)
        if is_compliant:
            return None

        # IAM users and S3 buckets carry a 'name' rather than an 'id'.
        resource_id = resource.get('id') or resource.get('name')
        return Finding(
            id=f"{policy.id}_{resource_id}",
            title=policy.name,
            description=policy.description,
            resource_id=resource_id,
            resource_type=policy.resource_type,
            region=resource.get('region', 'global'),
            provider=CloudProvider.AWS,
            risk_level=policy.risk_level,
            compliance_frameworks=policy.compliance_frameworks,
            remediation="\n".join(policy.remediation_steps),
            evidence=evidence
        )

    # Check functions: each returns (is_compliant, evidence_dict).
    def _check_s3_encryption(self, bucket: Dict) -> tuple:
        """Check if S3 bucket has encryption enabled."""
        encrypted = bucket.get('encryption') is not None
        return encrypted, {"encryption_config": bucket.get('encryption')}

    def _check_s3_versioning(self, bucket: Dict) -> tuple:
        """Check if S3 versioning is enabled."""
        enabled = bucket.get('versioning') == 'Enabled'
        return enabled, {"versioning_status": bucket.get('versioning')}

    def _check_s3_public_access(self, bucket: Dict) -> tuple:
        """Check that all four S3 public access block settings are on."""
        config = bucket.get('public_access_block', {})
        all_blocked = all([
            config.get('BlockPublicAcls', False),
            config.get('IgnorePublicAcls', False),
            config.get('BlockPublicPolicy', False),
            config.get('RestrictPublicBuckets', False)
        ]) if config else False
        return all_blocked, {"public_access_block": config}

    def _check_s3_logging(self, bucket: Dict) -> tuple:
        """Check if S3 logging is enabled."""
        enabled = bucket.get('logging') is not None
        return enabled, {"logging_config": bucket.get('logging')}

    def _check_ec2_imdsv2(self, instance: Dict) -> tuple:
        """Check if EC2 requires IMDSv2 (HttpTokens == 'required')."""
        metadata_options = instance.get('metadata_options', {})
        required = metadata_options.get('HttpTokens') == 'required'
        return required, {"metadata_options": metadata_options}

    def _check_ec2_public_ip(self, instance: Dict) -> tuple:
        """Check that an EC2 instance has no public IP (having one fails)."""
        has_public_ip = instance.get('public_ip') is not None
        return not has_public_ip, {"public_ip": instance.get('public_ip')}

    def _check_sg_open_ports(self, sg: Dict) -> tuple:
        """Check for security groups exposing sensitive ports to the world.

        An inbound rule is reported once per world-open CIDR (IPv4 or IPv6)
        when its port range contains any of ``DANGEROUS_PORTS``. A rule
        without ``FromPort``/``ToPort`` (e.g. protocol ``-1``, all traffic)
        is treated as covering every port. ICMP rules are ignored because
        their port fields encode ICMP type and code.
        """
        open_rules = []

        for rule in sg.get('inbound_rules', []):
            if str(rule.get('IpProtocol')) in self.ICMP_PROTOCOLS:
                continue

            from_port = rule.get('FromPort', 0)
            to_port = rule.get('ToPort', 65535)
            # Test containment, not just the endpoints: a 0-65535 rule
            # exposes port 22 even though neither endpoint equals 22.
            exposed = [port for port in self.DANGEROUS_PORTS
                       if from_port <= port <= to_port]
            if not exposed:
                continue

            world_cidrs = [
                ip_range.get('CidrIp') for ip_range in rule.get('IpRanges', [])
                if ip_range.get('CidrIp') in self.WORLD_CIDRS_V4
            ] + [
                ip_range.get('CidrIpv6') for ip_range in rule.get('Ipv6Ranges', [])
                if ip_range.get('CidrIpv6') in self.WORLD_CIDRS_V6
            ]
            for cidr in world_cidrs:
                open_rules.append({
                    "port_range": f"{from_port}-{to_port}",
                    "cidr": cidr,
                    "exposed_ports": exposed
                })

        return len(open_rules) == 0, {"open_ports": open_rules}

    def _check_iam_mfa(self, user: Dict) -> tuple:
        """Check if IAM user has MFA enabled."""
        enabled = user.get('mfa_enabled', False)
        return enabled, {"mfa_enabled": enabled}

    def _check_iam_key_rotation(self, user: Dict) -> tuple:
        """Check that no active IAM access key exceeds the maximum age.

        Key creation times are ISO-8601 strings; aware timestamps are
        converted to naive UTC before comparing with ``datetime.utcnow()``.
        """
        now = datetime.utcnow()
        old_keys = []

        for key in user.get('access_keys', []):
            if key['status'] != 'Active':
                continue

            created = datetime.fromisoformat(key['created'].replace('Z', '+00:00'))
            # Convert to UTC before dropping tzinfo; discarding a non-UTC
            # offset directly would skew the age by that offset.
            offset = created.utcoffset()
            if offset is not None:
                created = (created - offset).replace(tzinfo=None)

            age_days = (now - created).days
            if age_days > self.MAX_ACCESS_KEY_AGE_DAYS:
                old_keys.append({
                    "key_id": key['id'],
                    "age_days": age_days
                })

        return len(old_keys) == 0, {"old_keys": old_keys}

    def _check_rds_public(self, db: Dict) -> tuple:
        """Check that an RDS instance is not publicly accessible."""
        public = db.get('publicly_accessible', False)
        return not public, {"publicly_accessible": public}

    def _check_rds_encryption(self, db: Dict) -> tuple:
        """Check if RDS storage is encrypted."""
        encrypted = db.get('storage_encrypted', False)
        return encrypted, {"storage_encrypted": encrypted}

    def _check_rds_backup(self, db: Dict) -> tuple:
        """Check if RDS has automated backups (retention > 0 days)."""
        retention = db.get('backup_retention', 0)
        return retention > 0, {"backup_retention_days": retention}

    def _check_cloudtrail_enabled(self, trail: Dict) -> tuple:
        """Check that a CloudTrail trail is logging and multi-region."""
        logging = trail.get('is_logging', False)
        multi_region = trail.get('is_multi_region', False)
        return logging and multi_region, {
            "is_logging": logging,
            "is_multi_region": multi_region
        }

    def _check_kms_rotation(self, key: Dict) -> tuple:
        """Check if KMS key rotation is enabled."""
        rotation = key.get('rotation_enabled', False)
        return rotation, {"rotation_enabled": rotation}

Compliance Reporting

class CSPMReporter:
    """Builds dictionary-shaped CSPM reports from scanner findings."""

    # Placeholder: findings only describe failures, so the number of passing
    # checks per framework is assumed. Replace with real check counts.
    ASSUMED_PASSING_CHECKS = 100

    # Base remediation priority per risk level; INFO is absent and scores 0.
    RISK_SCORES = {
        RiskLevel.CRITICAL: 100,
        RiskLevel.HIGH: 75,
        RiskLevel.MEDIUM: 50,
        RiskLevel.LOW: 25
    }

    # Extra priority for findings mapped to these compliance frameworks.
    FRAMEWORK_BOOSTS = {
        ComplianceFramework.PCI_DSS: 20,
        ComplianceFramework.HIPAA: 15
    }

    def __init__(self):
        self.findings: List[Finding] = []

    def generate_report(self,
                        findings: List[Finding],
                        framework: Optional[ComplianceFramework] = None) -> Dict:
        """Generate comprehensive CSPM report.

        Args:
            findings: Findings to report on. The unfiltered list is stored on
                ``self.findings``.
            framework: When given, only findings mapped to this framework
                are included in every section of the report.

        Returns:
            Report dictionary with a summary, groupings by risk, resource
            type and provider, compliance status, remediation priorities,
            trend data, and per-finding details.
        """
        self.findings = findings

        if framework:
            findings = [f for f in findings if framework in f.compliance_frameworks]

        return {
            "generated_at": datetime.utcnow().isoformat(),
            "framework_filter": framework.value if framework else "all",
            "summary": self._generate_summary(findings),
            "findings_by_risk": self._group_by_risk(findings),
            "findings_by_resource": self._group_by_resource(findings),
            "findings_by_provider": self._group_by_provider(findings),
            "compliance_status": self._calculate_compliance(findings),
            "remediation_priorities": self._prioritize_remediation(findings),
            "trend_data": self._get_trend_data(),
            "detailed_findings": [self._serialize_finding(f) for f in findings]
        }

    def _serialize_finding(self, finding: Finding) -> Dict:
        """Convert one finding into the report's detailed-finding entry."""
        return {
            "id": finding.id,
            "title": finding.title,
            "description": finding.description,
            "resource_id": finding.resource_id,
            "resource_type": finding.resource_type,
            "region": finding.region,
            "provider": finding.provider.value,
            "risk_level": finding.risk_level.value,
            "compliance_frameworks": [c.value for c in finding.compliance_frameworks],
            "remediation": finding.remediation,
            "discovered_at": finding.discovered_at.isoformat()
        }

    def _generate_summary(self, findings: List[Finding]) -> Dict:
        """Generate findings summary with a count for every risk level."""
        counts = {level: 0 for level in RiskLevel}
        for finding in findings:
            counts[finding.risk_level] += 1

        return {
            "total_findings": len(findings),
            "critical": counts[RiskLevel.CRITICAL],
            "high": counts[RiskLevel.HIGH],
            "medium": counts[RiskLevel.MEDIUM],
            "low": counts[RiskLevel.LOW],
            # Previously omitted, so the per-level counts did not add up
            # to total_findings when INFO findings were present.
            "info": counts[RiskLevel.INFO],
            "unique_resources": len({f.resource_id for f in findings}),
            "providers_affected": len({f.provider for f in findings})
        }

    def _group_by_risk(self, findings: List[Finding]) -> Dict:
        """Group findings by risk level (every level present, even if empty).

        Resource IDs are de-duplicated and sorted so the report is stable
        across runs.
        """
        grouped = {}
        for level in RiskLevel:
            level_findings = [f for f in findings if f.risk_level == level]
            grouped[level.value] = {
                "count": len(level_findings),
                "resources": sorted({f.resource_id for f in level_findings}, key=str)
            }
        return grouped

    def _group_by_resource(self, findings: List[Finding]) -> Dict:
        """Group finding IDs by resource type."""
        grouped = {}
        for finding in findings:
            bucket = grouped.setdefault(finding.resource_type, {"count": 0, "findings": []})
            bucket["count"] += 1
            bucket["findings"].append(finding.id)
        return grouped

    def _group_by_provider(self, findings: List[Finding]) -> Dict:
        """Group findings by cloud provider, with per-risk counts."""
        grouped = {}
        for finding in findings:
            bucket = grouped.setdefault(finding.provider.value, {"count": 0, "by_risk": {}})
            bucket["count"] += 1
            risk = finding.risk_level.value
            bucket["by_risk"][risk] = bucket["by_risk"].get(risk, 0) + 1
        return grouped

    def _calculate_compliance(self, findings: List[Finding]) -> Dict:
        """Calculate compliance percentage by framework.

        NOTE: the passing count is the ``ASSUMED_PASSING_CHECKS`` placeholder,
        not a measured value, so percentages are illustrative only.
        """
        compliance = {}
        passing = self.ASSUMED_PASSING_CHECKS

        for framework in ComplianceFramework:
            failing = sum(1 for f in findings if framework in f.compliance_frameworks)
            total_checks = passing + failing

            compliance[framework.value] = {
                "passing": passing,
                "failing": failing,
                "percentage": (passing / total_checks * 100) if total_checks > 0 else 100
            }

        return compliance

    def _prioritize_remediation(self, findings: List[Finding], limit: int = 20) -> List[Dict]:
        """Rank findings for remediation by risk and framework sensitivity.

        Args:
            findings: Findings to rank.
            limit: Maximum number of entries returned (default 20).

        Returns:
            Up to ``limit`` entries, highest ``priority_score`` first. Ties
            keep their input order because ``sorted`` is stable.
        """
        scored = []
        for finding in findings:
            score = self.RISK_SCORES.get(finding.risk_level, 0)
            score += sum(boost for framework, boost in self.FRAMEWORK_BOOSTS.items()
                         if framework in finding.compliance_frameworks)

            scored.append({
                "finding_id": finding.id,
                "title": finding.title,
                "resource_id": finding.resource_id,
                "risk_level": finding.risk_level.value,
                "priority_score": score,
                "remediation": finding.remediation
            })

        return sorted(scored, key=lambda x: x["priority_score"], reverse=True)[:limit]

    def _get_trend_data(self) -> Dict:
        """Return trend data for dashboards.

        Placeholder: no historical data is stored, so both series are empty.
        """
        return {
            "last_7_days": [],
            "last_30_days": []
        }

Best Practices

CSPM Implementation

  1. Start with critical resources: Focus on publicly accessible resources first
  2. Automate scanning: Run scans continuously (on a schedule and on every infrastructure change), not just during periodic audits
  3. Prioritize remediation: Focus on high-risk findings with clear remediation
  4. Track compliance trends: Monitor improvement over time

Multi-Cloud Considerations

  • Use consistent policies across providers where possible
  • Map provider-specific services to common security controls
  • Maintain separate baselines for each provider's native features
  • Implement unified reporting across all cloud environments

CSPM provides essential visibility into cloud security posture and enables proactive identification of misconfigurations before they become breaches.

Weekly AI Security & Automation Digest

Get the latest on AI Security, workflow automation, secure integrations, and custom platform development delivered weekly.

No spam. Unsubscribe anytime.