Adversarial machine learning represents one of the most significant security challenges in AI deployment. Attackers can manipulate ML systems in ways that are imperceptible to humans but devastating to model performance. This guide explores attack vectors and provides practical defense implementations.
Understanding Adversarial Attack Categories
Attack Taxonomy
# adversarial_taxonomy.py
from enum import Enum
from dataclasses import dataclass
from typing import List, Optional
class AttackPhase(Enum):
TRAINING = "training" # Data poisoning, backdoor attacks
INFERENCE = "inference" # Evasion, adversarial examples
MODEL = "model" # Extraction, inversion
class AttackKnowledge(Enum):
WHITE_BOX = "white_box" # Full model access
BLACK_BOX = "black_box" # Query access only
GRAY_BOX = "gray_box" # Partial knowledge
class AttackGoal(Enum):
MISCLASSIFICATION = "misclassification"
TARGETED = "targeted" # Specific wrong output
AVAILABILITY = "availability" # Deny service
PRIVACY = "privacy" # Extract information
@dataclass
class AdversarialAttack:
name: str
phase: AttackPhase
knowledge: AttackKnowledge
goal: AttackGoal
perturbation_budget: Optional[float]
defense_difficulty: str
# Common attack types
ATTACK_CATALOG = [
AdversarialAttack(
name="FGSM",
phase=AttackPhase.INFERENCE,
knowledge=AttackKnowledge.WHITE_BOX,
goal=AttackGoal.MISCLASSIFICATION,
perturbation_budget=0.3,
defense_difficulty="medium"
),
AdversarialAttack(
name="PGD",
phase=AttackPhase.INFERENCE,
knowledge=AttackKnowledge.WHITE_BOX,
goal=AttackGoal.TARGETED,
perturbation_budget=0.1,
defense_difficulty="high"
),
AdversarialAttack(
name="Data Poisoning",
phase=AttackPhase.TRAINING,
knowledge=AttackKnowledge.BLACK_BOX,
goal=AttackGoal.MISCLASSIFICATION,
perturbation_budget=None,
defense_difficulty="high"
),
AdversarialAttack(
name="Model Extraction",
phase=AttackPhase.MODEL,
knowledge=AttackKnowledge.BLACK_BOX,
goal=AttackGoal.PRIVACY,
perturbation_budget=None,
defense_difficulty="high"
)
]
Evasion Attacks and Defenses
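Evasion attacks perturb inputs at inference time while keeping the change inside a perturbation budget, usually measured in the L-infinity norm. When evaluating either attack implemented below, it helps to confirm that crafted examples actually stay within that budget; a minimal sketch, with the tensors and epsilon supplied by the caller:
# linf_budget_check.py (illustrative sketch)
import torch

def linf_distance(x: torch.Tensor, x_adv: torch.Tensor) -> float:
    """Largest absolute per-element change between clean and adversarial inputs"""
    return (x_adv - x).abs().max().item()

def within_budget(x: torch.Tensor, x_adv: torch.Tensor, epsilon: float) -> bool:
    """Check that the adversarial example stays inside the L-infinity epsilon ball"""
    return linf_distance(x, x_adv) <= epsilon + 1e-6  # small tolerance for float error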
Fast Gradient Sign Method (FGSM) Implementation
# fgsm_attack.py
import torch
import torch.nn as nn
import numpy as np
class FGSMAttack:
"""Fast Gradient Sign Method adversarial attack"""
def __init__(self, model: nn.Module, epsilon: float = 0.3):
self.model = model
self.epsilon = epsilon
self.model.eval()
def generate(
self,
x: torch.Tensor,
y: torch.Tensor,
targeted: bool = False,
target_class: int = None
) -> torch.Tensor:
"""Generate adversarial examples"""
x = x.clone().detach().requires_grad_(True)
# Forward pass
outputs = self.model(x)
loss_fn = nn.CrossEntropyLoss()
if targeted:
# Targeted attack: minimize loss toward target
target = torch.full_like(y, target_class)
loss = loss_fn(outputs, target)
else:
# Untargeted: maximize loss from true label
loss = loss_fn(outputs, y)
# Backward pass
self.model.zero_grad()
loss.backward()
# Generate perturbation
sign_grad = x.grad.sign()
if targeted:
# Move toward target (negative gradient)
x_adv = x - self.epsilon * sign_grad
else:
# Move away from true label (positive gradient)
x_adv = x + self.epsilon * sign_grad
# Clamp to valid range
x_adv = torch.clamp(x_adv, 0, 1)
return x_adv
def evaluate_attack(
self,
dataloader: torch.utils.data.DataLoader
) -> dict:
"""Evaluate attack success rate"""
correct_clean = 0
correct_adv = 0
total = 0
for x, y in dataloader:
# Clean prediction
with torch.no_grad():
clean_pred = self.model(x).argmax(dim=1)
correct_clean += (clean_pred == y).sum().item()
# Generate adversarial
x_adv = self.generate(x, y)
# Adversarial prediction
with torch.no_grad():
adv_pred = self.model(x_adv).argmax(dim=1)
correct_adv += (adv_pred == y).sum().item()
total += y.size(0)
return {
"clean_accuracy": correct_clean / total,
"adversarial_accuracy": correct_adv / total,
"attack_success_rate": (correct_clean - correct_adv) / correct_clean
}Projected Gradient Descent (PGD) Attack
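PGD strengthens FGSM by taking many small gradient-sign steps and projecting the running example back into the epsilon ball around the original input after each step, which makes it a much harder attack to defend against. A hedged sketch of how the PGDAttack class defined below might be compared against FGSM on a single batch (the model, inputs, and labels are placeholders you would supply):
# attack_comparison_example.py (illustrative sketch)
import torch
from fgsm_attack import FGSMAttack
from pgd_attack import PGDAttack

def compare_attacks(model, x, y, epsilon=0.3):
    """Accuracy under FGSM and PGD on one batch; lower means a stronger attack"""
    attacks = {
        "fgsm": FGSMAttack(model, epsilon=epsilon),
        "pgd": PGDAttack(model, epsilon=epsilon, alpha=0.01, iterations=40),
    }
    results = {}
    for name, attack in attacks.items():
        x_adv = attack.generate(x, y)
        with torch.no_grad():
            preds = model(x_adv).argmax(dim=1)
        results[name] = (preds == y).float().mean().item()
    return results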
# pgd_attack.py
import torch
import torch.nn as nn
class PGDAttack:
"""Projected Gradient Descent - stronger iterative attack"""
def __init__(
self,
model: nn.Module,
epsilon: float = 0.3,
alpha: float = 0.01,
iterations: int = 40,
random_start: bool = True
):
self.model = model
self.epsilon = epsilon
self.alpha = alpha
self.iterations = iterations
self.random_start = random_start
def generate(
self,
x: torch.Tensor,
y: torch.Tensor,
targeted: bool = False,
target_class: int = None
) -> torch.Tensor:
"""Generate PGD adversarial examples"""
x_orig = x.clone().detach()
# Random initialization within epsilon ball
if self.random_start:
x_adv = x + torch.empty_like(x).uniform_(-self.epsilon, self.epsilon)
x_adv = torch.clamp(x_adv, 0, 1)
else:
x_adv = x.clone()
for _ in range(self.iterations):
x_adv = x_adv.clone().detach().requires_grad_(True)
outputs = self.model(x_adv)
loss_fn = nn.CrossEntropyLoss()
if targeted:
target = torch.full_like(y, target_class)
loss = loss_fn(outputs, target)
else:
loss = loss_fn(outputs, y)
self.model.zero_grad()
loss.backward()
# Gradient step
with torch.no_grad():
if targeted:
x_adv = x_adv - self.alpha * x_adv.grad.sign()
else:
x_adv = x_adv + self.alpha * x_adv.grad.sign()
# Project back to epsilon ball around original
perturbation = x_adv - x_orig
perturbation = torch.clamp(perturbation, -self.epsilon, self.epsilon)
x_adv = torch.clamp(x_orig + perturbation, 0, 1)
        return x_adv
Adversarial Training Defense
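Adversarial training is among the most reliable empirical defenses against evasion: attack examples are generated on the fly during training so the model learns to classify them correctly, trading some clean accuracy for robustness. A sketch of driving a few epochs with the AdversarialTrainer defined below (the model, data loader, and hyperparameters are placeholders):
# adversarial_training_example.py (illustrative sketch)
import torch
import torch.nn as nn
from adversarial_training import AdversarialTrainer

def robust_training_run(model, train_loader, epochs=10, lr=1e-3):
    """Run adversarial training and print per-epoch metrics"""
    trainer = AdversarialTrainer(model, attack_method="pgd", epsilon=0.3, mix_ratio=0.5)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        stats = trainer.train_epoch(train_loader, optimizer, criterion)
        print(
            f"epoch {epoch + 1}: loss={stats['loss']:.4f} "
            f"clean_acc={stats['clean_accuracy']:.3f} adv_acc={stats['adversarial_accuracy']:.3f}"
        )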
# adversarial_training.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from fgsm_attack import FGSMAttack
from pgd_attack import PGDAttack
class AdversarialTrainer:
"""Train models with adversarial examples for robustness"""
def __init__(
self,
model: nn.Module,
attack_method: str = "pgd",
epsilon: float = 0.3,
mix_ratio: float = 0.5
):
self.model = model
self.epsilon = epsilon
self.mix_ratio = mix_ratio
if attack_method == "fgsm":
self.attack = FGSMAttack(model, epsilon)
else:
self.attack = PGDAttack(model, epsilon)
def train_epoch(
self,
dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
criterion: nn.Module
) -> dict:
"""Train one epoch with adversarial examples"""
self.model.train()
total_loss = 0
clean_correct = 0
adv_correct = 0
        clean_total = 0
        adv_total = 0
for batch_idx, (x, y) in enumerate(dataloader):
# Decide which samples get adversarial perturbation
batch_size = x.size(0)
num_adv = int(batch_size * self.mix_ratio)
# Split batch
x_clean = x[num_adv:]
y_clean = y[num_adv:]
x_for_adv = x[:num_adv]
y_for_adv = y[:num_adv]
# Generate adversarial examples
self.model.eval()
x_adv = self.attack.generate(x_for_adv, y_for_adv)
self.model.train()
# Combine clean and adversarial
x_combined = torch.cat([x_adv, x_clean], dim=0)
y_combined = torch.cat([y_for_adv, y_clean], dim=0)
# Forward pass
optimizer.zero_grad()
outputs = self.model(x_combined)
loss = criterion(outputs, y_combined)
# Backward pass
loss.backward()
optimizer.step()
# Statistics
total_loss += loss.item()
predictions = outputs.argmax(dim=1)
            adv_correct += (predictions[:num_adv] == y_for_adv).sum().item()
            clean_correct += (predictions[num_adv:] == y_clean).sum().item()
            adv_total += num_adv
            clean_total += batch_size - num_adv
return {
"loss": total_loss / len(dataloader),
"clean_accuracy": clean_correct / (total * (1 - self.mix_ratio)),
"adversarial_accuracy": adv_correct / (total * self.mix_ratio)
}Input Validation and Sanitization
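Before any adversarial-specific detector runs, inputs should pass ordinary validation: expected shape, dtype, finiteness, and value range. Malformed or out-of-range tensors are cheap to reject and often indicate probing. A minimal pre-check sketch (the expected image shape is an assumption for a small image classifier):
# input_validation.py (illustrative sketch)
import torch

def validate_input(x: torch.Tensor, expected_shape=(3, 32, 32)) -> bool:
    """Reject inputs that violate the basic contract of the model"""
    if x.dim() != len(expected_shape) + 1:      # leading batch dimension expected
        return False
    if tuple(x.shape[1:]) != expected_shape:    # channel/height/width mismatch
        return False
    if not torch.isfinite(x).all():             # NaN or inf values
        return False
    if x.min() < 0.0 or x.max() > 1.0:          # outside the normalized pixel range
        return False
    return True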
Feature Squeezing Defense
# feature_squeezing.py
import torch
import torch.nn as nn
import numpy as np
from scipy.ndimage import median_filter
class FeatureSqueezer:
"""Detect adversarial examples through input transformation"""
def __init__(
self,
model: nn.Module,
bit_depth: int = 4,
median_filter_size: int = 2,
threshold: float = 0.1
):
self.model = model
self.bit_depth = bit_depth
self.median_filter_size = median_filter_size
self.threshold = threshold
def bit_depth_reduction(self, x: torch.Tensor) -> torch.Tensor:
"""Reduce color bit depth"""
levels = 2 ** self.bit_depth
return torch.round(x * levels) / levels
def spatial_smoothing(self, x: torch.Tensor) -> torch.Tensor:
"""Apply median filter"""
        x_np = x.detach().cpu().numpy()
smoothed = np.zeros_like(x_np)
for i in range(x_np.shape[0]):
for c in range(x_np.shape[1]):
smoothed[i, c] = median_filter(
x_np[i, c],
size=self.median_filter_size
)
return torch.from_numpy(smoothed).float()
def detect(self, x: torch.Tensor) -> tuple:
"""Detect adversarial examples"""
self.model.eval()
with torch.no_grad():
# Original prediction
pred_orig = torch.softmax(self.model(x), dim=1)
# Bit depth reduction prediction
x_bit = self.bit_depth_reduction(x)
pred_bit = torch.softmax(self.model(x_bit), dim=1)
# Spatial smoothing prediction
x_smooth = self.spatial_smoothing(x)
pred_smooth = torch.softmax(self.model(x_smooth), dim=1)
# Compute L1 distances
dist_bit = torch.abs(pred_orig - pred_bit).sum(dim=1)
dist_smooth = torch.abs(pred_orig - pred_smooth).sum(dim=1)
# Maximum distance
max_dist = torch.max(dist_bit, dist_smooth)
# Detection decision
is_adversarial = max_dist > self.threshold
return is_adversarial, {
"bit_depth_distance": dist_bit,
"smoothing_distance": dist_smooth,
"max_distance": max_dist
}
def filter_adversarial(
self,
x: torch.Tensor
) -> tuple:
"""Filter out detected adversarial examples"""
is_adv, distances = self.detect(x)
# Return only clean samples
clean_mask = ~is_adv
x_clean = x[clean_mask]
return x_clean, {
"original_count": x.size(0),
"clean_count": x_clean.size(0),
"filtered_count": is_adv.sum().item()
        }
Certified Defense with Randomized Smoothing
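Randomized smoothing turns any base classifier into a smoothed classifier that predicts the class most likely under Gaussian noise added to the input. If a lower confidence bound p on the top class's probability exceeds 0.5, the smoothed prediction is certifiably robust within an L2 radius of sigma times the inverse standard-normal CDF of p (Cohen et al., 2019). A sketch of certifying a single input with the class defined below (the model and the size-one input batch are placeholders; more noise samples give tighter certificates):
# smoothing_example.py (illustrative sketch)
from randomized_smoothing import RandomizedSmoothing

def certify_one(model, x):
    """Certify a single input (batch of size 1) and report the result"""
    smoother = RandomizedSmoothing(model, sigma=0.5, num_samples=1000, alpha=0.001)
    pred_class, radius = smoother.certify(x)
    if pred_class == -1:
        return {"certified": False}
    return {"certified": True, "class": int(pred_class), "l2_radius": float(radius)}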
# randomized_smoothing.py
import torch
import torch.nn as nn
from scipy.stats import norm
import numpy as np
class RandomizedSmoothing:
"""Certified defense through randomized smoothing"""
def __init__(
self,
model: nn.Module,
sigma: float = 0.5,
num_samples: int = 100,
alpha: float = 0.001
):
self.model = model
self.sigma = sigma
self.num_samples = num_samples
self.alpha = alpha
def certify(
self,
x: torch.Tensor,
num_samples: int = None
) -> tuple:
"""Certify prediction with radius"""
if num_samples is None:
num_samples = self.num_samples
self.model.eval()
# Sample predictions
counts = self._sample_predictions(x, num_samples)
# Get top class and runner-up
top_class = counts.argmax()
top_count = counts[top_class]
# Statistical test for certification
p_lower = self._lower_confidence_bound(top_count, num_samples)
if p_lower > 0.5:
# Certified - compute radius
radius = self.sigma * norm.ppf(p_lower)
return top_class, radius
else:
# Cannot certify
return -1, 0.0
def _sample_predictions(
self,
x: torch.Tensor,
num_samples: int
) -> np.ndarray:
"""Get prediction counts under noise"""
        # Assumes the wrapped model exposes a num_classes attribute
        counts = np.zeros(self.model.num_classes)
batch_size = min(100, num_samples)
num_batches = (num_samples + batch_size - 1) // batch_size
with torch.no_grad():
            remaining = num_samples
            for _ in range(num_batches):
                this_batch = min(batch_size, remaining)
                remaining -= this_batch
                # Replicate the input so each pass evaluates this_batch noisy copies
                x_rep = x.repeat(this_batch, *([1] * (x.dim() - 1)))
                # Add Gaussian noise
                noise = torch.randn_like(x_rep) * self.sigma
                x_noisy = x_rep + noise
                # Get predictions on the noisy batch
                predictions = self.model(x_noisy).argmax(dim=1)
                # Count votes per class
                for pred in predictions.cpu().numpy():
                    counts[pred] += 1
return counts
def _lower_confidence_bound(
self,
k: int,
n: int
) -> float:
"""Compute lower bound of binomial proportion"""
from scipy.stats import binom
return binom.ppf(self.alpha, n, k/n) / n
def predict_and_certify_batch(
self,
dataloader: torch.utils.data.DataLoader
) -> dict:
"""Certify predictions on entire dataset"""
certified_correct = 0
total = 0
certified_radii = []
for x, y in dataloader:
for i in range(x.size(0)):
pred_class, radius = self.certify(x[i:i+1])
if pred_class == y[i].item() and radius > 0:
certified_correct += 1
certified_radii.append(radius)
total += 1
return {
"certified_accuracy": certified_correct / total,
"average_certified_radius": np.mean(certified_radii) if certified_radii else 0,
"total_samples": total
        }
Model Extraction Defense
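Model extraction attacks repeatedly query a deployed model, record its outputs, and train a local surrogate that mimics it; once the attacker has a faithful copy they can craft white-box adversarial examples offline or simply reuse the stolen model. A simplified sketch of the attacker's side, useful when red-teaming your own API (the query function, surrogate architecture, and probe data are placeholders):
# extraction_redteam_sketch.py (illustrative)
import torch
import torch.nn as nn

def train_surrogate(query_fn, surrogate: nn.Module, probe_loader, epochs=5, lr=1e-3):
    """Distill a surrogate from (input, victim-output) pairs gathered through the API"""
    optimizer = torch.optim.Adam(surrogate.parameters(), lr=lr)
    kl = nn.KLDivLoss(reduction="batchmean")
    for _ in range(epochs):
        for x, _ in probe_loader:
            with torch.no_grad():
                victim_probs = query_fn(x)  # soft outputs returned by the victim API
            optimizer.zero_grad()
            loss = kl(torch.log_softmax(surrogate(x), dim=1), victim_probs)
            loss.backward()
            optimizer.step()
    return surrogate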
Watermarking for Model Ownership
# model_watermarking.py
import torch
import torch.nn as nn
import numpy as np
import hashlib
class ModelWatermark:
"""Embed and verify ownership watermarks in ML models"""
def __init__(
self,
model: nn.Module,
num_watermarks: int = 100,
secret_key: str = "your-secret-key"
):
self.model = model
self.num_watermarks = num_watermarks
self.secret_key = secret_key
# Generate watermark triggers and targets
self.triggers, self.targets = self._generate_watermarks()
def _generate_watermarks(self) -> tuple:
"""Generate watermark trigger-target pairs"""
# Deterministic generation from secret
np.random.seed(
int(hashlib.sha256(self.secret_key.encode()).hexdigest()[:8], 16)
)
# Generate trigger patterns (unusual inputs)
triggers = []
targets = []
for i in range(self.num_watermarks):
# Create random trigger pattern
trigger = np.random.rand(1, 3, 32, 32).astype(np.float32)
triggers.append(torch.from_numpy(trigger))
# Assign target class (cycling through classes)
target = i % 10 # Assuming 10 classes
targets.append(target)
return triggers, targets
def embed_watermark(
self,
optimizer: torch.optim.Optimizer,
epochs: int = 10,
watermark_weight: float = 0.1
):
"""Fine-tune model to embed watermark"""
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
total_loss = 0
for trigger, target in zip(self.triggers, self.targets):
optimizer.zero_grad()
# Forward pass on trigger
output = self.model(trigger)
target_tensor = torch.tensor([target])
# Watermark loss
loss = watermark_weight * criterion(output, target_tensor)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Watermark Epoch {epoch+1}: Loss = {total_loss/self.num_watermarks:.4f}")
def verify_watermark(self, suspect_model: nn.Module) -> dict:
"""Verify if model contains watermark"""
suspect_model.eval()
correct = 0
with torch.no_grad():
for trigger, target in zip(self.triggers, self.targets):
output = suspect_model(trigger)
pred = output.argmax(dim=1).item()
if pred == target:
correct += 1
accuracy = correct / self.num_watermarks
# Statistical significance test
# Under random guessing, expected accuracy is ~10% (for 10 classes)
# Threshold for positive detection
threshold = 0.5
return {
"watermark_accuracy": accuracy,
"is_watermarked": accuracy > threshold,
"confidence": min(accuracy / threshold, 1.0),
"num_verified": correct,
"num_total": self.num_watermarks
        }
Query Rate Limiting and Anomaly Detection
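Both extraction and black-box evasion need large numbers of queries, so per-client monitoring of query volume and structure is a practical first line of defense. A sketch of gating an inference endpoint with the QueryProtection class defined below (the feature embedding and model function are placeholders; in practice this sits in your serving layer):
# query_gate_example.py (illustrative sketch)
import numpy as np
from query_protection import QueryProtection

protection = QueryProtection(rate_limit=100, window_minutes=1)

def guarded_predict(user_id: str, features: np.ndarray, model_fn):
    """Run the model only if the query passes rate and pattern checks"""
    decision = protection.check_query(user_id, features)
    if not decision["allowed"]:
        return {"error": decision["reason"]}
    return {"prediction": model_fn(features)}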
# query_protection.py
from datetime import datetime, timedelta
from collections import defaultdict
import numpy as np
from typing import Dict, List
class QueryProtection:
"""Protect against model extraction through query monitoring"""
def __init__(
self,
rate_limit: int = 100, # queries per minute
window_minutes: int = 1,
similarity_threshold: float = 0.8
):
self.rate_limit = rate_limit
self.window = timedelta(minutes=window_minutes)
self.similarity_threshold = similarity_threshold
# Track queries per user
self.query_history: Dict[str, List] = defaultdict(list)
self.query_embeddings: Dict[str, List] = defaultdict(list)
def check_query(
self,
user_id: str,
query_embedding: np.ndarray
) -> dict:
"""Check if query should be allowed"""
current_time = datetime.now()
# Clean old queries
self._cleanup_old_queries(user_id, current_time)
# Check rate limit
recent_count = len(self.query_history[user_id])
if recent_count >= self.rate_limit:
return {
"allowed": False,
"reason": "rate_limit_exceeded",
"retry_after_seconds": 60
}
# Check for systematic querying (extraction attempt)
if self._detect_extraction_pattern(user_id, query_embedding):
return {
"allowed": False,
"reason": "suspicious_pattern_detected",
"message": "Query pattern suggests model extraction attempt"
}
# Record query
self.query_history[user_id].append(current_time)
self.query_embeddings[user_id].append(query_embedding)
return {"allowed": True}
def _cleanup_old_queries(self, user_id: str, current_time: datetime):
"""Remove queries outside the time window"""
cutoff = current_time - self.window
# Filter history
history = self.query_history[user_id]
embeddings = self.query_embeddings[user_id]
valid_indices = [
i for i, t in enumerate(history) if t > cutoff
]
self.query_history[user_id] = [history[i] for i in valid_indices]
self.query_embeddings[user_id] = [embeddings[i] for i in valid_indices]
def _detect_extraction_pattern(
self,
user_id: str,
new_embedding: np.ndarray
) -> bool:
"""Detect systematic query patterns suggesting extraction"""
embeddings = self.query_embeddings[user_id]
if len(embeddings) < 10:
return False
# Check for grid-like pattern (systematic exploration)
recent = embeddings[-20:] # Last 20 queries
# Compute pairwise distances
distances = []
for i in range(len(recent)):
for j in range(i+1, len(recent)):
dist = np.linalg.norm(recent[i] - recent[j])
distances.append(dist)
# Extraction often shows uniform spacing
if len(distances) > 0:
std_dev = np.std(distances)
mean_dist = np.mean(distances)
            # Low variance in distances suggests systematic probing
            if mean_dist > 0 and std_dev / mean_dist < 0.2:
                return True
# Check for boundary probing
# (queries clustered around decision boundaries)
# Implementation depends on model-specific heuristics
return False
def get_user_stats(self, user_id: str) -> dict:
"""Get query statistics for user"""
return {
"queries_in_window": len(self.query_history[user_id]),
"rate_limit": self.rate_limit,
"remaining": max(0, self.rate_limit - len(self.query_history[user_id]))
        }
Comprehensive Defense Pipeline
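A production deployment chains these defenses along one request path: query-level protection first (cheapest to evaluate), then adversarial input detection, then the prediction itself, optionally with a certified radius. A sketch of assembling the pipeline class defined below from the earlier components (the trained model and all thresholds are placeholders to tune against your own traffic):
# pipeline_wiring_example.py (illustrative sketch)
from feature_squeezing import FeatureSqueezer
from randomized_smoothing import RandomizedSmoothing
from query_protection import QueryProtection
from defense_pipeline import AdversarialDefensePipeline

def build_pipeline(model):
    """Assemble the full defense pipeline around a trained classifier"""
    return AdversarialDefensePipeline(
        model=model,
        feature_squeezer=FeatureSqueezer(model, bit_depth=4, median_filter_size=2, threshold=0.1),
        smoothed_classifier=RandomizedSmoothing(model, sigma=0.5, num_samples=100),
        query_protection=QueryProtection(rate_limit=100, window_minutes=1),
    )

# Example: result = build_pipeline(model).predict(x, user_id="client-42")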
Production Defense System
# defense_pipeline.py
import torch
import numpy as np
from dataclasses import dataclass
from typing import Optional, List
import logging
from feature_squeezing import FeatureSqueezer
from randomized_smoothing import RandomizedSmoothing
from query_protection import QueryProtection
@dataclass
class DefenseResult:
allowed: bool
prediction: Optional[int]
confidence: Optional[float]
certified_radius: Optional[float]
warnings: List[str]
class AdversarialDefensePipeline:
"""Complete defense pipeline for production ML systems"""
def __init__(
self,
model: torch.nn.Module,
feature_squeezer: FeatureSqueezer,
smoothed_classifier: RandomizedSmoothing,
query_protection: QueryProtection
):
self.model = model
self.squeezer = feature_squeezer
self.smoother = smoothed_classifier
self.query_protection = query_protection
self.logger = logging.getLogger(__name__)
def predict(
self,
x: torch.Tensor,
user_id: str,
require_certification: bool = False
) -> DefenseResult:
"""Make prediction with full defense pipeline"""
warnings = []
# Step 1: Query rate limiting
query_embedding = self._compute_embedding(x)
rate_check = self.query_protection.check_query(user_id, query_embedding)
if not rate_check["allowed"]:
self.logger.warning(f"Query blocked for {user_id}: {rate_check['reason']}")
return DefenseResult(
allowed=False,
prediction=None,
confidence=None,
certified_radius=None,
warnings=[rate_check.get("message", rate_check["reason"])]
)
# Step 2: Adversarial detection
is_adversarial, detection_info = self.squeezer.detect(x)
if is_adversarial.any():
warnings.append("Potential adversarial input detected")
self.logger.warning(f"Adversarial detection triggered: {detection_info}")
# Option: reject or continue with caution
# Here we continue but flag it
# Step 3: Get prediction
if require_certification:
# Use randomized smoothing for certified prediction
pred_class, radius = self.smoother.certify(x)
if pred_class == -1:
return DefenseResult(
allowed=True,
prediction=None,
confidence=None,
certified_radius=0.0,
warnings=warnings + ["Could not certify prediction"]
)
return DefenseResult(
allowed=True,
prediction=pred_class,
confidence=None, # Certification replaces confidence
certified_radius=radius,
warnings=warnings
)
else:
# Standard prediction
self.model.eval()
with torch.no_grad():
output = self.model(x)
probs = torch.softmax(output, dim=1)
confidence, prediction = probs.max(dim=1)
return DefenseResult(
allowed=True,
prediction=prediction.item(),
confidence=confidence.item(),
certified_radius=None,
warnings=warnings
)
def _compute_embedding(self, x: torch.Tensor) -> np.ndarray:
"""Compute query embedding for monitoring"""
# Simple: flatten and subsample
        return x.detach().flatten().cpu().numpy()[:100]
def get_health_metrics(self) -> dict:
"""Get defense system health metrics"""
return {
"model_loaded": self.model is not None,
"defenses_active": {
"feature_squeezing": True,
"randomized_smoothing": True,
"query_protection": True
},
"configuration": {
"squeeze_bit_depth": self.squeezer.bit_depth,
"smoothing_sigma": self.smoother.sigma,
"rate_limit": self.query_protection.rate_limit
}
        }
Summary
Defending against adversarial ML attacks requires multiple layers:
- Detection: Feature squeezing, input validation
- Robustness: Adversarial training, certified defenses
- Monitoring: Query rate limiting, pattern detection
- Ownership: Model watermarking for extraction defense
No single defense is sufficient. Combine these techniques based on your threat model and deploy comprehensive monitoring to detect novel attacks. Regular red-teaming and security assessments are essential for maintaining robust ML systems.