AI Security

Adversarial Machine Learning: Attack Vectors and Defense Strategies

DeviDevs Team
12 min read
#adversarial ML, #AI security, #machine learning, #cybersecurity, #model robustness

Adversarial machine learning represents one of the most significant security challenges in AI deployment. Attackers can manipulate ML systems in ways that are imperceptible to humans but devastating to model performance. This guide explores attack vectors and provides practical defense implementations.

Understanding Adversarial Attack Categories

Attack Taxonomy

# adversarial_taxonomy.py
from enum import Enum
from dataclasses import dataclass
from typing import List, Optional
 
class AttackPhase(Enum):
    TRAINING = "training"  # Data poisoning, backdoor attacks
    INFERENCE = "inference"  # Evasion, adversarial examples
    MODEL = "model"  # Extraction, inversion
 
class AttackKnowledge(Enum):
    WHITE_BOX = "white_box"  # Full model access
    BLACK_BOX = "black_box"  # Query access only
    GRAY_BOX = "gray_box"  # Partial knowledge
 
class AttackGoal(Enum):
    MISCLASSIFICATION = "misclassification"
    TARGETED = "targeted"  # Specific wrong output
    AVAILABILITY = "availability"  # Deny service
    PRIVACY = "privacy"  # Extract information
 
@dataclass
class AdversarialAttack:
    name: str
    phase: AttackPhase
    knowledge: AttackKnowledge
    goal: AttackGoal
    perturbation_budget: Optional[float]
    defense_difficulty: str
 
# Common attack types
ATTACK_CATALOG = [
    AdversarialAttack(
        name="FGSM",
        phase=AttackPhase.INFERENCE,
        knowledge=AttackKnowledge.WHITE_BOX,
        goal=AttackGoal.MISCLASSIFICATION,
        perturbation_budget=0.3,
        defense_difficulty="medium"
    ),
    AdversarialAttack(
        name="PGD",
        phase=AttackPhase.INFERENCE,
        knowledge=AttackKnowledge.WHITE_BOX,
        goal=AttackGoal.TARGETED,
        perturbation_budget=0.1,
        defense_difficulty="high"
    ),
    AdversarialAttack(
        name="Data Poisoning",
        phase=AttackPhase.TRAINING,
        knowledge=AttackKnowledge.BLACK_BOX,
        goal=AttackGoal.MISCLASSIFICATION,
        perturbation_budget=None,
        defense_difficulty="high"
    ),
    AdversarialAttack(
        name="Model Extraction",
        phase=AttackPhase.MODEL,
        knowledge=AttackKnowledge.BLACK_BOX,
        goal=AttackGoal.PRIVACY,
        perturbation_budget=None,
        defense_difficulty="high"
    )
]
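
As a quick illustration, the catalog can be filtered when planning attack coverage for a red-team exercise. The snippet below is a minimal sketch that relies only on the definitions above:

# example: list inference-time white-box attacks from the catalog
inference_white_box = [
    attack.name
    for attack in ATTACK_CATALOG
    if attack.phase == AttackPhase.INFERENCE
    and attack.knowledge == AttackKnowledge.WHITE_BOX
]
print(inference_white_box)  # ['FGSM', 'PGD']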

Evasion Attacks and Defenses

Fast Gradient Sign Method (FGSM) Implementation

# fgsm_attack.py
from typing import Optional

import torch
import torch.nn as nn
 
class FGSMAttack:
    """Fast Gradient Sign Method adversarial attack"""
 
    def __init__(self, model: nn.Module, epsilon: float = 0.3):
        self.model = model
        self.epsilon = epsilon
        self.model.eval()
 
    def generate(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        targeted: bool = False,
        target_class: Optional[int] = None
    ) -> torch.Tensor:
        """Generate adversarial examples"""
        x = x.clone().detach().requires_grad_(True)
 
        # Forward pass
        outputs = self.model(x)
        loss_fn = nn.CrossEntropyLoss()
 
        if targeted:
            # Targeted attack: minimize loss toward target
            target = torch.full_like(y, target_class)
            loss = loss_fn(outputs, target)
        else:
            # Untargeted: maximize loss from true label
            loss = loss_fn(outputs, y)
 
        # Backward pass
        self.model.zero_grad()
        loss.backward()
 
        # Generate perturbation
        sign_grad = x.grad.sign()
 
        if targeted:
            # Move toward target (negative gradient)
            x_adv = x - self.epsilon * sign_grad
        else:
            # Move away from true label (positive gradient)
            x_adv = x + self.epsilon * sign_grad
 
        # Clamp to valid range
        x_adv = torch.clamp(x_adv, 0, 1)
 
        return x_adv
 
    def evaluate_attack(
        self,
        dataloader: torch.utils.data.DataLoader
    ) -> dict:
        """Evaluate attack success rate"""
        correct_clean = 0
        correct_adv = 0
        total = 0
 
        for x, y in dataloader:
            # Clean prediction
            with torch.no_grad():
                clean_pred = self.model(x).argmax(dim=1)
                correct_clean += (clean_pred == y).sum().item()
 
            # Generate adversarial
            x_adv = self.generate(x, y)
 
            # Adversarial prediction
            with torch.no_grad():
                adv_pred = self.model(x_adv).argmax(dim=1)
                correct_adv += (adv_pred == y).sum().item()
 
            total += y.size(0)
 
        return {
            "clean_accuracy": correct_clean / total,
            "adversarial_accuracy": correct_adv / total,
            "attack_success_rate": (correct_clean - correct_adv) / correct_clean
        }
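
A minimal usage sketch follows; the model and test_loader are placeholders for whatever classifier and dataset you are evaluating, and epsilon assumes inputs scaled to [0, 1]:

# example: evaluate FGSM against a trained classifier (hypothetical model and loader)
attack = FGSMAttack(model, epsilon=0.1)
results = attack.evaluate_attack(test_loader)
print(f"Clean accuracy: {results['clean_accuracy']:.2%}")
print(f"Adversarial accuracy: {results['adversarial_accuracy']:.2%}")
print(f"Attack success rate: {results['attack_success_rate']:.2%}")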

Projected Gradient Descent (PGD) Attack

# pgd_attack.py
from typing import Optional

import torch
import torch.nn as nn
 
class PGDAttack:
    """Projected Gradient Descent - stronger iterative attack"""
 
    def __init__(
        self,
        model: nn.Module,
        epsilon: float = 0.3,
        alpha: float = 0.01,
        iterations: int = 40,
        random_start: bool = True
    ):
        self.model = model
        self.epsilon = epsilon
        self.alpha = alpha
        self.iterations = iterations
        self.random_start = random_start
 
    def generate(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        targeted: bool = False,
        target_class: Optional[int] = None
    ) -> torch.Tensor:
        """Generate PGD adversarial examples"""
        x_orig = x.clone().detach()
 
        # Random initialization within epsilon ball
        if self.random_start:
            x_adv = x + torch.empty_like(x).uniform_(-self.epsilon, self.epsilon)
            x_adv = torch.clamp(x_adv, 0, 1)
        else:
            x_adv = x.clone()
 
        for _ in range(self.iterations):
            x_adv = x_adv.clone().detach().requires_grad_(True)
 
            outputs = self.model(x_adv)
            loss_fn = nn.CrossEntropyLoss()
 
            if targeted:
                target = torch.full_like(y, target_class)
                loss = loss_fn(outputs, target)
            else:
                loss = loss_fn(outputs, y)
 
            self.model.zero_grad()
            loss.backward()
 
            # Gradient step
            with torch.no_grad():
                if targeted:
                    x_adv = x_adv - self.alpha * x_adv.grad.sign()
                else:
                    x_adv = x_adv + self.alpha * x_adv.grad.sign()
 
                # Project back to epsilon ball around original
                perturbation = x_adv - x_orig
                perturbation = torch.clamp(perturbation, -self.epsilon, self.epsilon)
                x_adv = torch.clamp(x_orig + perturbation, 0, 1)
 
        return x_adv
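
Because PGD takes many small steps and projects back into the epsilon ball after each one, it typically degrades accuracy far more than single-step FGSM at the same budget. A brief comparison sketch, again assuming a hypothetical model and a single batch x_batch, y_batch:

# example: compare FGSM and PGD on one batch (hypothetical model and data)
fgsm = FGSMAttack(model, epsilon=0.03)
pgd = PGDAttack(model, epsilon=0.03, alpha=0.007, iterations=40)

x_fgsm = fgsm.generate(x_batch, y_batch)
x_pgd = pgd.generate(x_batch, y_batch)

with torch.no_grad():
    acc_fgsm = (model(x_fgsm).argmax(dim=1) == y_batch).float().mean().item()
    acc_pgd = (model(x_pgd).argmax(dim=1) == y_batch).float().mean().item()

print(f"Accuracy under FGSM: {acc_fgsm:.2%}, under PGD: {acc_pgd:.2%}")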

Adversarial Training Defense

# adversarial_training.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from fgsm_attack import FGSMAttack
from pgd_attack import PGDAttack
 
class AdversarialTrainer:
    """Train models with adversarial examples for robustness"""
 
    def __init__(
        self,
        model: nn.Module,
        attack_method: str = "pgd",
        epsilon: float = 0.3,
        mix_ratio: float = 0.5
    ):
        self.model = model
        self.epsilon = epsilon
        self.mix_ratio = mix_ratio
 
        if attack_method == "fgsm":
            self.attack = FGSMAttack(model, epsilon)
        else:
            self.attack = PGDAttack(model, epsilon)
 
    def train_epoch(
        self,
        dataloader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module
    ) -> dict:
        """Train one epoch with adversarial examples"""
        self.model.train()
        total_loss = 0
        clean_correct = 0
        adv_correct = 0
        clean_total = 0
        adv_total = 0
 
        for batch_idx, (x, y) in enumerate(dataloader):
            # Decide which samples get adversarial perturbation
            batch_size = x.size(0)
            num_adv = int(batch_size * self.mix_ratio)
 
            # Split batch
            x_clean = x[num_adv:]
            y_clean = y[num_adv:]
            x_for_adv = x[:num_adv]
            y_for_adv = y[:num_adv]
 
            # Generate adversarial examples
            self.model.eval()
            x_adv = self.attack.generate(x_for_adv, y_for_adv)
            self.model.train()
 
            # Combine clean and adversarial
            x_combined = torch.cat([x_adv, x_clean], dim=0)
            y_combined = torch.cat([y_for_adv, y_clean], dim=0)
 
            # Forward pass
            optimizer.zero_grad()
            outputs = self.model(x_combined)
            loss = criterion(outputs, y_combined)
 
            # Backward pass
            loss.backward()
            optimizer.step()
 
            # Statistics: track clean and adversarial counts separately
            total_loss += loss.item()
            predictions = outputs.argmax(dim=1)
            adv_correct += (predictions[:num_adv] == y_for_adv).sum().item()
            clean_correct += (predictions[num_adv:] == y_clean).sum().item()
            adv_total += num_adv
            clean_total += batch_size - num_adv
 
        return {
            "loss": total_loss / len(dataloader),
            "clean_accuracy": clean_correct / max(clean_total, 1),
            "adversarial_accuracy": adv_correct / max(adv_total, 1)
        }
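
A sketch of the outer training loop might look like the following; the model, train_loader, and hyperparameters here are assumptions standing in for your own setup:

# example: adversarial training loop (hypothetical model and data loader)
trainer = AdversarialTrainer(model, attack_method="pgd", epsilon=0.03, mix_ratio=0.5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    stats = trainer.train_epoch(train_loader, optimizer, criterion)
    print(
        f"Epoch {epoch+1}: loss={stats['loss']:.4f} "
        f"clean_acc={stats['clean_accuracy']:.2%} "
        f"adv_acc={stats['adversarial_accuracy']:.2%}"
    )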

Input Validation and Sanitization

Feature Squeezing Defense

# feature_squeezing.py
import torch
import torch.nn as nn
import numpy as np
from scipy.ndimage import median_filter
 
class FeatureSqueezer:
    """Detect adversarial examples through input transformation"""
 
    def __init__(
        self,
        model: nn.Module,
        bit_depth: int = 4,
        median_filter_size: int = 2,
        threshold: float = 0.1
    ):
        self.model = model
        self.bit_depth = bit_depth
        self.median_filter_size = median_filter_size
        self.threshold = threshold
 
    def bit_depth_reduction(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce color bit depth"""
        levels = 2 ** self.bit_depth
        return torch.round(x * levels) / levels
 
    def spatial_smoothing(self, x: torch.Tensor) -> torch.Tensor:
        """Apply median filter"""
        x_np = x.detach().cpu().numpy()
        smoothed = np.zeros_like(x_np)
 
        for i in range(x_np.shape[0]):
            for c in range(x_np.shape[1]):
                smoothed[i, c] = median_filter(
                    x_np[i, c],
                    size=self.median_filter_size
                )
 
        return torch.from_numpy(smoothed).float()
 
    def detect(self, x: torch.Tensor) -> tuple:
        """Detect adversarial examples"""
        self.model.eval()
 
        with torch.no_grad():
            # Original prediction
            pred_orig = torch.softmax(self.model(x), dim=1)
 
            # Bit depth reduction prediction
            x_bit = self.bit_depth_reduction(x)
            pred_bit = torch.softmax(self.model(x_bit), dim=1)
 
            # Spatial smoothing prediction
            x_smooth = self.spatial_smoothing(x)
            pred_smooth = torch.softmax(self.model(x_smooth), dim=1)
 
        # Compute L1 distances
        dist_bit = torch.abs(pred_orig - pred_bit).sum(dim=1)
        dist_smooth = torch.abs(pred_orig - pred_smooth).sum(dim=1)
 
        # Maximum distance
        max_dist = torch.max(dist_bit, dist_smooth)
 
        # Detection decision
        is_adversarial = max_dist > self.threshold
 
        return is_adversarial, {
            "bit_depth_distance": dist_bit,
            "smoothing_distance": dist_smooth,
            "max_distance": max_dist
        }
 
    def filter_adversarial(
        self,
        x: torch.Tensor
    ) -> tuple:
        """Filter out detected adversarial examples"""
        is_adv, distances = self.detect(x)
 
        # Return only clean samples
        clean_mask = ~is_adv
        x_clean = x[clean_mask]
 
        return x_clean, {
            "original_count": x.size(0),
            "clean_count": x_clean.size(0),
            "filtered_count": is_adv.sum().item()
        }
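
In practice the detector sits in front of the classifier and either flags or drops suspicious inputs. A minimal sketch, assuming a hypothetical batch x_batch of images scaled to [0, 1]:

# example: screen a batch before classification (hypothetical model and batch)
squeezer = FeatureSqueezer(model, bit_depth=4, median_filter_size=2, threshold=0.1)

is_adv, distances = squeezer.detect(x_batch)
print(f"Flagged {is_adv.sum().item()} of {x_batch.size(0)} inputs as adversarial")

x_clean, stats = squeezer.filter_adversarial(x_batch)
print(stats)  # counts of original, clean, and filtered samples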

Certified Defense with Randomized Smoothing

# randomized_smoothing.py
import torch
import torch.nn as nn
from scipy.stats import norm
import numpy as np
 
class RandomizedSmoothing:
    """Certified defense through randomized smoothing"""
 
    def __init__(
        self,
        model: nn.Module,
        sigma: float = 0.5,
        num_samples: int = 100,
        alpha: float = 0.001
    ):
        self.model = model
        self.sigma = sigma
        self.num_samples = num_samples
        self.alpha = alpha
 
    def certify(
        self,
        x: torch.Tensor,
        num_samples: int = None
    ) -> tuple:
        """Certify prediction with radius"""
        if num_samples is None:
            num_samples = self.num_samples
 
        self.model.eval()
 
        # Sample predictions
        counts = self._sample_predictions(x, num_samples)
 
        # Get top class and runner-up
        top_class = counts.argmax()
        top_count = counts[top_class]
 
        # Statistical test for certification
        p_lower = self._lower_confidence_bound(top_count, num_samples)
 
        if p_lower > 0.5:
            # Certified - compute radius
            radius = self.sigma * norm.ppf(p_lower)
            return top_class, radius
        else:
            # Cannot certify
            return -1, 0.0
 
    def _sample_predictions(
        self,
        x: torch.Tensor,
        num_samples: int
    ) -> np.ndarray:
        """Get prediction counts for num_samples noisy copies of x"""
        counts = np.zeros(self.model.num_classes)  # model is assumed to expose num_classes

        batch_size = min(100, num_samples)
        remaining = num_samples

        with torch.no_grad():
            while remaining > 0:
                current = min(batch_size, remaining)

                # Replicate the single input and add independent Gaussian noise
                x_batch = x.repeat(current, *([1] * (x.dim() - 1)))
                noise = torch.randn_like(x_batch) * self.sigma
                x_noisy = x_batch + noise

                # Get predictions for each noisy copy
                outputs = self.model(x_noisy)
                predictions = outputs.argmax(dim=1)

                # Count votes per class
                for pred in predictions.cpu().numpy():
                    counts[pred] += 1

                remaining -= current

        return counts
 
    def _lower_confidence_bound(
        self,
        k: int,
        n: int
    ) -> float:
        """One-sided (1 - alpha) Clopper-Pearson lower bound on the binomial proportion k/n"""
        from scipy.stats import beta
        if k == 0:
            return 0.0
        return float(beta.ppf(self.alpha, k, n - k + 1))
 
    def predict_and_certify_batch(
        self,
        dataloader: torch.utils.data.DataLoader
    ) -> dict:
        """Certify predictions on entire dataset"""
        certified_correct = 0
        total = 0
        certified_radii = []
 
        for x, y in dataloader:
            for i in range(x.size(0)):
                pred_class, radius = self.certify(x[i:i+1])
 
                if pred_class == y[i].item() and radius > 0:
                    certified_correct += 1
                    certified_radii.append(radius)
 
                total += 1
 
        return {
            "certified_accuracy": certified_correct / total,
            "average_certified_radius": np.mean(certified_radii) if certified_radii else 0,
            "total_samples": total
        }
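
Certification is expensive (hundreds of forward passes per input), so it is usually reserved for high-stakes predictions, and sigma should match the noise the classifier was trained under. A sketch for a single sample, assuming a hypothetical model that exposes num_classes:

# example: certify a single input (hypothetical model and sample)
smoother = RandomizedSmoothing(model, sigma=0.25, num_samples=1000, alpha=0.001)
pred_class, radius = smoother.certify(x_sample)  # x_sample shaped [1, C, H, W]

if pred_class == -1:
    print("Abstained: could not certify this input")
else:
    print(f"Class {pred_class} certified within L2 radius {radius:.3f}")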

Model Extraction Defense

Watermarking for Model Ownership

# model_watermarking.py
import torch
import torch.nn as nn
import numpy as np
import hashlib
 
class ModelWatermark:
    """Embed and verify ownership watermarks in ML models"""
 
    def __init__(
        self,
        model: nn.Module,
        num_watermarks: int = 100,
        secret_key: str = "your-secret-key"
    ):
        self.model = model
        self.num_watermarks = num_watermarks
        self.secret_key = secret_key
 
        # Generate watermark triggers and targets
        self.triggers, self.targets = self._generate_watermarks()
 
    def _generate_watermarks(self) -> tuple:
        """Generate watermark trigger-target pairs"""
        # Deterministic generation from secret
        np.random.seed(
            int(hashlib.sha256(self.secret_key.encode()).hexdigest()[:8], 16)
        )
 
        # Generate trigger patterns (unusual inputs)
        triggers = []
        targets = []
 
        for i in range(self.num_watermarks):
            # Create random trigger pattern
            trigger = np.random.rand(1, 3, 32, 32).astype(np.float32)
            triggers.append(torch.from_numpy(trigger))
 
            # Assign target class (cycling through classes)
            target = i % 10  # Assuming 10 classes
            targets.append(target)
 
        return triggers, targets
 
    def embed_watermark(
        self,
        optimizer: torch.optim.Optimizer,
        epochs: int = 10,
        watermark_weight: float = 0.1
    ):
        """Fine-tune model to embed watermark"""
        criterion = nn.CrossEntropyLoss()
 
        for epoch in range(epochs):
            total_loss = 0
 
            for trigger, target in zip(self.triggers, self.targets):
                optimizer.zero_grad()
 
                # Forward pass on trigger
                output = self.model(trigger)
                target_tensor = torch.tensor([target])
 
                # Watermark loss
                loss = watermark_weight * criterion(output, target_tensor)
                loss.backward()
                optimizer.step()
 
                total_loss += loss.item()
 
            print(f"Watermark Epoch {epoch+1}: Loss = {total_loss/self.num_watermarks:.4f}")
 
    def verify_watermark(self, suspect_model: nn.Module) -> dict:
        """Verify if model contains watermark"""
        suspect_model.eval()
        correct = 0
 
        with torch.no_grad():
            for trigger, target in zip(self.triggers, self.targets):
                output = suspect_model(trigger)
                pred = output.argmax(dim=1).item()
 
                if pred == target:
                    correct += 1
 
        accuracy = correct / self.num_watermarks
 
        # Statistical significance test
        # Under random guessing, expected accuracy is ~10% (for 10 classes)
        # Threshold for positive detection
        threshold = 0.5
 
        return {
            "watermark_accuracy": accuracy,
            "is_watermarked": accuracy > threshold,
            "confidence": min(accuracy / threshold, 1.0),
            "num_verified": correct,
            "num_total": self.num_watermarks
        }
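
The same trigger set used for embedding is later replayed against a suspect model; a trigger accuracy far above chance (roughly 10% for 10 classes) indicates the model was derived from the watermarked one. A usage sketch with assumed optimizer settings:

# example: embed a watermark, then verify a suspect copy (hypothetical models)
watermark = ModelWatermark(model, num_watermarks=100, secret_key="org-secret-key")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
watermark.embed_watermark(optimizer, epochs=5, watermark_weight=0.1)

report = watermark.verify_watermark(suspect_model)
print(report["is_watermarked"], report["watermark_accuracy"])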

Query Rate Limiting and Anomaly Detection

# query_protection.py
from datetime import datetime, timedelta
from collections import defaultdict
import numpy as np
from typing import Dict, List
 
class QueryProtection:
    """Protect against model extraction through query monitoring"""
 
    def __init__(
        self,
        rate_limit: int = 100,  # queries per minute
        window_minutes: int = 1,
        similarity_threshold: float = 0.8
    ):
        self.rate_limit = rate_limit
        self.window = timedelta(minutes=window_minutes)
        self.similarity_threshold = similarity_threshold
 
        # Track queries per user
        self.query_history: Dict[str, List] = defaultdict(list)
        self.query_embeddings: Dict[str, List] = defaultdict(list)
 
    def check_query(
        self,
        user_id: str,
        query_embedding: np.ndarray
    ) -> dict:
        """Check if query should be allowed"""
        current_time = datetime.now()
 
        # Clean old queries
        self._cleanup_old_queries(user_id, current_time)
 
        # Check rate limit
        recent_count = len(self.query_history[user_id])
        if recent_count >= self.rate_limit:
            return {
                "allowed": False,
                "reason": "rate_limit_exceeded",
                "retry_after_seconds": 60
            }
 
        # Check for systematic querying (extraction attempt)
        if self._detect_extraction_pattern(user_id, query_embedding):
            return {
                "allowed": False,
                "reason": "suspicious_pattern_detected",
                "message": "Query pattern suggests model extraction attempt"
            }
 
        # Record query
        self.query_history[user_id].append(current_time)
        self.query_embeddings[user_id].append(query_embedding)
 
        return {"allowed": True}
 
    def _cleanup_old_queries(self, user_id: str, current_time: datetime):
        """Remove queries outside the time window"""
        cutoff = current_time - self.window
 
        # Filter history
        history = self.query_history[user_id]
        embeddings = self.query_embeddings[user_id]
 
        valid_indices = [
            i for i, t in enumerate(history) if t > cutoff
        ]
 
        self.query_history[user_id] = [history[i] for i in valid_indices]
        self.query_embeddings[user_id] = [embeddings[i] for i in valid_indices]
 
    def _detect_extraction_pattern(
        self,
        user_id: str,
        new_embedding: np.ndarray
    ) -> bool:
        """Detect systematic query patterns suggesting extraction"""
        embeddings = self.query_embeddings[user_id]
 
        if len(embeddings) < 10:
            return False
 
        # Check for grid-like pattern (systematic exploration)
        recent = embeddings[-20:]  # Last 20 queries
 
        # Compute pairwise distances
        distances = []
        for i in range(len(recent)):
            for j in range(i+1, len(recent)):
                dist = np.linalg.norm(recent[i] - recent[j])
                distances.append(dist)
 
        # Extraction often shows uniform spacing
        if len(distances) > 0:
            std_dev = np.std(distances)
            mean_dist = np.mean(distances)
 
            # Low variance in distances suggests systematic probing
            if mean_dist > 0 and std_dev / mean_dist < 0.2:
                return True
 
        # Check for boundary probing
        # (queries clustered around decision boundaries)
        # Implementation depends on model-specific heuristics
 
        return False
 
    def get_user_stats(self, user_id: str) -> dict:
        """Get query statistics for user"""
        return {
            "queries_in_window": len(self.query_history[user_id]),
            "rate_limit": self.rate_limit,
            "remaining": max(0, self.rate_limit - len(self.query_history[user_id]))
        }
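
Wired into an inference endpoint, the check runs before the model is ever called; the query embedding can be any compact representation of the request that is already available. A minimal sketch with a hypothetical user ID and input tensor:

# example: gate an inference request (hypothetical user and input)
protection = QueryProtection(rate_limit=100, window_minutes=1)

query_embedding = x_request.flatten().numpy()[:100]
decision = protection.check_query("user-123", query_embedding)

if not decision["allowed"]:
    print("Rejected:", decision["reason"])
else:
    print("Remaining quota:", protection.get_user_stats("user-123")["remaining"])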

Comprehensive Defense Pipeline

Production Defense System

# defense_pipeline.py
import logging
from dataclasses import dataclass
from typing import Optional, List

import numpy as np
import torch

from feature_squeezing import FeatureSqueezer
from randomized_smoothing import RandomizedSmoothing
from query_protection import QueryProtection
 
@dataclass
class DefenseResult:
    allowed: bool
    prediction: Optional[int]
    confidence: Optional[float]
    certified_radius: Optional[float]
    warnings: List[str]
 
class AdversarialDefensePipeline:
    """Complete defense pipeline for production ML systems"""
 
    def __init__(
        self,
        model: torch.nn.Module,
        feature_squeezer: FeatureSqueezer,
        smoothed_classifier: RandomizedSmoothing,
        query_protection: QueryProtection
    ):
        self.model = model
        self.squeezer = feature_squeezer
        self.smoother = smoothed_classifier
        self.query_protection = query_protection
        self.logger = logging.getLogger(__name__)
 
    def predict(
        self,
        x: torch.Tensor,
        user_id: str,
        require_certification: bool = False
    ) -> DefenseResult:
        """Make prediction with full defense pipeline"""
        warnings = []
 
        # Step 1: Query rate limiting
        query_embedding = self._compute_embedding(x)
        rate_check = self.query_protection.check_query(user_id, query_embedding)
 
        if not rate_check["allowed"]:
            self.logger.warning(f"Query blocked for {user_id}: {rate_check['reason']}")
            return DefenseResult(
                allowed=False,
                prediction=None,
                confidence=None,
                certified_radius=None,
                warnings=[rate_check.get("message", rate_check["reason"])]
            )
 
        # Step 2: Adversarial detection
        is_adversarial, detection_info = self.squeezer.detect(x)
 
        if is_adversarial.any():
            warnings.append("Potential adversarial input detected")
            self.logger.warning(f"Adversarial detection triggered: {detection_info}")
 
            # Option: reject or continue with caution
            # Here we continue but flag it
 
        # Step 3: Get prediction
        if require_certification:
            # Use randomized smoothing for certified prediction
            pred_class, radius = self.smoother.certify(x)
 
            if pred_class == -1:
                return DefenseResult(
                    allowed=True,
                    prediction=None,
                    confidence=None,
                    certified_radius=0.0,
                    warnings=warnings + ["Could not certify prediction"]
                )
 
            return DefenseResult(
                allowed=True,
                prediction=pred_class,
                confidence=None,  # Certification replaces confidence
                certified_radius=radius,
                warnings=warnings
            )
        else:
            # Standard prediction
            self.model.eval()
            with torch.no_grad():
                output = self.model(x)
                probs = torch.softmax(output, dim=1)
                confidence, prediction = probs.max(dim=1)
 
            return DefenseResult(
                allowed=True,
                prediction=prediction.item(),
                confidence=confidence.item(),
                certified_radius=None,
                warnings=warnings
            )
 
    def _compute_embedding(self, x: torch.Tensor) -> np.ndarray:
        """Compute query embedding for monitoring"""
        # Simple: flatten and subsample
        return x.detach().cpu().flatten().numpy()[:100]
 
    def get_health_metrics(self) -> dict:
        """Get defense system health metrics"""
        return {
            "model_loaded": self.model is not None,
            "defenses_active": {
                "feature_squeezing": True,
                "randomized_smoothing": True,
                "query_protection": True
            },
            "configuration": {
                "squeeze_bit_depth": self.squeezer.bit_depth,
                "smoothing_sigma": self.smoother.sigma,
                "rate_limit": self.query_protection.rate_limit
            }
        }
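
Wiring the pieces together might look like the sketch below; every constructor argument is an assumption to be tuned against your own threat model and input scale:

# example: assemble and query the defense pipeline (hypothetical model and request)
pipeline = AdversarialDefensePipeline(
    model=model,
    feature_squeezer=FeatureSqueezer(model, bit_depth=4, threshold=0.1),
    smoothed_classifier=RandomizedSmoothing(model, sigma=0.25, num_samples=500),
    query_protection=QueryProtection(rate_limit=100, window_minutes=1)
)

result = pipeline.predict(x_request, user_id="user-123", require_certification=False)
if result.allowed:
    print(result.prediction, result.confidence, result.warnings)
else:
    print("Blocked:", result.warnings)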

Summary

Defending against adversarial ML attacks requires multiple layers:

  1. Detection: Feature squeezing, input validation
  2. Robustness: Adversarial training, certified defenses
  3. Monitoring: Query rate limiting, pattern detection
  4. Ownership: Model watermarking for extraction defense

No single defense is sufficient. Combine these techniques based on your threat model and deploy comprehensive monitoring to detect novel attacks. Regular red-teaming and security assessments are essential for maintaining robust ML systems.
