Federated learning enables training machine learning models across decentralized data sources without exposing raw data. However, this architecture introduces unique security challenges. This guide covers implementing secure federated learning systems with practical defenses.
Federated Learning Architecture
Core System Design
# federated_learning_server.py
import hashlib
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Dict, List, Optional

import numpy as np
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
@dataclass
class ClientUpdate:
client_id: str
model_weights: Dict[str, np.ndarray]
num_samples: int
round_number: int
signature: bytes
metadata: Dict = field(default_factory=dict)
@dataclass
class FederatedRound:
round_number: int
participating_clients: List[str]
global_weights: Dict[str, np.ndarray]
aggregation_method: str
timestamp: str
class SecureFederatedServer:
"""Secure federated learning server with privacy guarantees"""
def __init__(
self,
model_architecture: Dict,
min_clients: int = 3,
aggregation_method: str = "fedavg"
):
self.model_arch = model_architecture
self.min_clients = min_clients
self.aggregation_method = aggregation_method
self.global_weights = self._initialize_weights()
self.registered_clients = {}
self.round_history = []
self.current_round = 0
def _initialize_weights(self) -> Dict[str, np.ndarray]:
"""Initialize global model weights"""
weights = {}
for layer_name, shape in self.model_arch.items():
# Xavier initialization
fan_in = shape[0] if len(shape) > 0 else 1
fan_out = shape[1] if len(shape) > 1 else 1
limit = np.sqrt(6 / (fan_in + fan_out))
weights[layer_name] = np.random.uniform(-limit, limit, shape)
return weights
def register_client(self, client_id: str, public_key: bytes) -> str:
"""Register client with public key for authentication"""
if client_id in self.registered_clients:
raise ValueError(f"Client {client_id} already registered")
# Generate client token
token = secrets.token_hex(32)
self.registered_clients[client_id] = {
"public_key": public_key,
"token": hashlib.sha256(token.encode()).hexdigest(),
"rounds_participated": 0,
"reputation_score": 1.0
}
return token
def start_round(self) -> Dict:
"""Start new training round"""
self.current_round += 1
return {
"round_number": self.current_round,
"global_weights": self.global_weights,
"min_samples": 100,
"deadline_seconds": 300
}
def aggregate_updates(
self,
updates: List[ClientUpdate]
) -> Dict[str, np.ndarray]:
"""Securely aggregate client updates"""
if len(updates) < self.min_clients:
raise ValueError(
f"Need at least {self.min_clients} clients, got {len(updates)}"
)
# Verify all updates
verified_updates = []
for update in updates:
if self._verify_update(update):
verified_updates.append(update)
if len(verified_updates) < self.min_clients:
raise ValueError("Not enough verified updates")
# Apply aggregation method
if self.aggregation_method == "fedavg":
aggregated = self._fedavg(verified_updates)
elif self.aggregation_method == "median":
aggregated = self._coordinate_median(verified_updates)
elif self.aggregation_method == "trimmed_mean":
aggregated = self._trimmed_mean(verified_updates, trim_ratio=0.1)
else:
raise ValueError(f"Unknown aggregation: {self.aggregation_method}")
# Update global weights
self.global_weights = aggregated
# Record round
self.round_history.append(FederatedRound(
round_number=self.current_round,
participating_clients=[u.client_id for u in verified_updates],
global_weights=aggregated,
aggregation_method=self.aggregation_method,
timestamp=datetime.utcnow().isoformat()
))
return aggregated
def _verify_update(self, update: ClientUpdate) -> bool:
"""Verify client update authenticity and validity"""
# Check client is registered
if update.client_id not in self.registered_clients:
return False
# Check round number
if update.round_number != self.current_round:
return False
# Verify signature (simplified)
# In production, use proper cryptographic verification
# Check weight shapes match
for layer, weights in update.model_weights.items():
if layer not in self.model_arch:
return False
if weights.shape != tuple(self.model_arch[layer]):
return False
return True
def _fedavg(self, updates: List[ClientUpdate]) -> Dict[str, np.ndarray]:
"""Federated averaging aggregation"""
total_samples = sum(u.num_samples for u in updates)
aggregated = {}
for layer in self.global_weights.keys():
weighted_sum = np.zeros_like(self.global_weights[layer])
for update in updates:
weight = update.num_samples / total_samples
weighted_sum += weight * update.model_weights[layer]
aggregated[layer] = weighted_sum
return aggregated
def _coordinate_median(
self,
updates: List[ClientUpdate]
) -> Dict[str, np.ndarray]:
"""Coordinate-wise median (Byzantine-robust)"""
aggregated = {}
for layer in self.global_weights.keys():
# Stack all updates for this layer
stacked = np.stack([u.model_weights[layer] for u in updates])
# Take median along client axis
aggregated[layer] = np.median(stacked, axis=0)
return aggregated
def _trimmed_mean(
self,
updates: List[ClientUpdate],
trim_ratio: float = 0.1
) -> Dict[str, np.ndarray]:
"""Trimmed mean aggregation (removes outliers)"""
from scipy import stats
aggregated = {}
n_trim = int(len(updates) * trim_ratio)
for layer in self.global_weights.keys():
stacked = np.stack([u.model_weights[layer] for u in updates])
# Trimmed mean along client axis
aggregated[layer] = stats.trim_mean(stacked, trim_ratio, axis=0)
        return aggregated

Differential Privacy Implementation
Privacy-Preserving Gradient Clipping
# differential_privacy.py
import numpy as np
from typing import Dict, Tuple
from dataclasses import dataclass
@dataclass
class PrivacyBudget:
epsilon: float # Privacy parameter
delta: float # Failure probability
consumed_epsilon: float = 0.0
max_rounds: int = 100
class DifferentialPrivacy:
"""Implement differential privacy for federated learning"""
def __init__(
self,
epsilon: float = 1.0,
delta: float = 1e-5,
clip_norm: float = 1.0,
noise_multiplier: float = 1.1
):
self.epsilon = epsilon
self.delta = delta
self.clip_norm = clip_norm
self.noise_multiplier = noise_multiplier
def clip_gradients(
self,
gradients: Dict[str, np.ndarray]
) -> Dict[str, np.ndarray]:
"""Clip gradients to bounded sensitivity"""
# Compute global L2 norm
total_norm = 0.0
for grad in gradients.values():
total_norm += np.sum(grad ** 2)
total_norm = np.sqrt(total_norm)
# Clip if necessary
clip_factor = min(1.0, self.clip_norm / (total_norm + 1e-6))
clipped = {}
for name, grad in gradients.items():
clipped[name] = grad * clip_factor
return clipped
def add_noise(
self,
gradients: Dict[str, np.ndarray],
num_samples: int
) -> Dict[str, np.ndarray]:
"""Add calibrated Gaussian noise for DP"""
# Compute noise scale
noise_scale = (
self.clip_norm * self.noise_multiplier / num_samples
)
noisy = {}
for name, grad in gradients.items():
noise = np.random.normal(0, noise_scale, grad.shape)
noisy[name] = grad + noise
return noisy
def compute_privacy_spent(
self,
num_rounds: int,
sampling_rate: float
) -> Tuple[float, float]:
"""Compute privacy budget spent using RDP accountant"""
from scipy.special import comb
# Simplified RDP to (ε, δ)-DP conversion
# In production, use tensorflow-privacy or opacus
alpha = 1 + 1 / (self.noise_multiplier ** 2)
rdp = alpha * num_rounds * sampling_rate ** 2 / (2 * self.noise_multiplier ** 2)
# Convert RDP to (ε, δ)-DP
epsilon = rdp + np.log(1 / self.delta) / (alpha - 1)
return epsilon, self.delta
def privatize_update(
self,
model_update: Dict[str, np.ndarray],
num_samples: int
) -> Dict[str, np.ndarray]:
"""Apply full DP pipeline to model update"""
# Step 1: Clip gradients
clipped = self.clip_gradients(model_update)
# Step 2: Add noise
noisy = self.add_noise(clipped, num_samples)
return noisy
class LocalDifferentialPrivacy:
"""Local differential privacy for individual data points"""
def __init__(self, epsilon: float = 1.0):
self.epsilon = epsilon
def randomized_response(self, bit: bool) -> bool:
"""Randomized response mechanism for binary data"""
p = np.exp(self.epsilon) / (1 + np.exp(self.epsilon))
if np.random.random() < p:
return bit
else:
return not bit
def laplace_mechanism(self, value: float, sensitivity: float) -> float:
"""Laplace mechanism for numeric data"""
scale = sensitivity / self.epsilon
noise = np.random.laplace(0, scale)
return value + noise
def privatize_histogram(
self,
data: np.ndarray,
num_bins: int
) -> np.ndarray:
"""Create private histogram with LDP"""
# Each user reports one bin with probability
p = np.exp(self.epsilon) / (np.exp(self.epsilon) + num_bins - 1)
q = 1 / (np.exp(self.epsilon) + num_bins - 1)
# Simulate responses
histogram = np.zeros(num_bins)
for value in data:
true_bin = int(value * num_bins) % num_bins
# Randomized response
if np.random.random() < p:
histogram[true_bin] += 1
else:
# Report random other bin
other_bins = [i for i in range(num_bins) if i != true_bin]
random_bin = np.random.choice(other_bins)
histogram[random_bin] += 1
# Unbias the histogram
n = len(data)
unbiased = (histogram - n * q) / (p - q)
        return np.maximum(unbiased, 0)

Secure Aggregation Protocol
Cryptographic Secure Aggregation
# secure_aggregation.py
import numpy as np
from cryptography.hazmat.primitives.asymmetric import x25519
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import hashlib
import os
from typing import Dict, List, Tuple
class SecureAggregation:
"""Secure aggregation using pairwise masking"""
def __init__(self, num_clients: int):
self.num_clients = num_clients
self.client_keys = {}
self.shared_secrets = {}
def generate_keypair(self, client_id: str) -> Tuple[bytes, bytes]:
"""Generate key pair for client"""
private_key = x25519.X25519PrivateKey.generate()
public_key = private_key.public_key()
private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption()
)
public_bytes = public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw
)
self.client_keys[client_id] = {
"private": private_key,
"public": public_bytes
}
return private_bytes, public_bytes
def compute_shared_secrets(
self,
client_id: str,
other_public_keys: Dict[str, bytes]
) -> Dict[str, bytes]:
"""Compute shared secrets with other clients"""
private_key = self.client_keys[client_id]["private"]
secrets = {}
for other_id, public_bytes in other_public_keys.items():
if other_id == client_id:
continue
other_public = x25519.X25519PublicKey.from_public_bytes(public_bytes)
shared = private_key.exchange(other_public)
# Derive mask seed from shared secret
secrets[other_id] = hashlib.sha256(shared).digest()
self.shared_secrets[client_id] = secrets
return secrets
def generate_mask(
self,
seed: bytes,
shape: tuple,
client_id: str,
other_id: str
) -> np.ndarray:
"""Generate pseudorandom mask from seed"""
# Use seed to initialize RNG
combined = f"{min(client_id, other_id)}:{max(client_id, other_id)}"
full_seed = hashlib.sha256(seed + combined.encode()).digest()
rng = np.random.default_rng(
int.from_bytes(full_seed[:8], 'big')
)
mask = rng.standard_normal(shape)
# Determine sign based on client ordering
if client_id < other_id:
return mask
else:
return -mask
def mask_update(
self,
client_id: str,
update: Dict[str, np.ndarray]
) -> Dict[str, np.ndarray]:
"""Mask client update with pairwise masks"""
masked = {}
secrets = self.shared_secrets.get(client_id, {})
for layer, weights in update.items():
masked_weights = weights.copy()
# Add masks from shared secrets
for other_id, seed in secrets.items():
mask = self.generate_mask(
seed, weights.shape, client_id, other_id
)
masked_weights += mask
masked[layer] = masked_weights
return masked
def aggregate_masked(
self,
masked_updates: Dict[str, Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
"""Aggregate masked updates (masks cancel out)"""
# When all clients participate, masks sum to zero
aggregated = {}
first_client = list(masked_updates.keys())[0]
for layer in masked_updates[first_client].keys():
layer_sum = np.zeros_like(masked_updates[first_client][layer])
for client_id, update in masked_updates.items():
layer_sum += update[layer]
# Average
aggregated[layer] = layer_sum / len(masked_updates)
return aggregated
class SecretSharing:
"""Shamir's Secret Sharing for dropout resilience"""
def __init__(self, threshold: int, num_shares: int):
self.threshold = threshold
self.num_shares = num_shares
self.prime = 2**127 - 1 # Mersenne prime
def split(self, secret: int) -> List[Tuple[int, int]]:
"""Split secret into shares"""
# Generate random polynomial coefficients
coeffs = [secret] + [
np.random.randint(0, self.prime)
for _ in range(self.threshold - 1)
]
# Evaluate polynomial at points 1, 2, ..., num_shares
shares = []
for x in range(1, self.num_shares + 1):
y = 0
for i, coeff in enumerate(coeffs):
y = (y + coeff * pow(x, i, self.prime)) % self.prime
shares.append((x, y))
return shares
def reconstruct(self, shares: List[Tuple[int, int]]) -> int:
"""Reconstruct secret from shares using Lagrange interpolation"""
if len(shares) < self.threshold:
raise ValueError(
f"Need at least {self.threshold} shares, got {len(shares)}"
)
secret = 0
for i, (xi, yi) in enumerate(shares[:self.threshold]):
numerator = 1
denominator = 1
for j, (xj, _) in enumerate(shares[:self.threshold]):
if i != j:
numerator = (numerator * (-xj)) % self.prime
denominator = (denominator * (xi - xj)) % self.prime
# Modular inverse
lagrange = (
yi * numerator * pow(denominator, -1, self.prime)
) % self.prime
secret = (secret + lagrange) % self.prime
        return secret

Byzantine-Robust Aggregation
Defense Against Poisoning Attacks
# byzantine_defense.py
import numpy as np
from typing import List, Dict, Tuple
from sklearn.cluster import DBSCAN
from scipy.spatial.distance import cosine
class ByzantineDefense:
"""Defend against Byzantine/poisoning attacks in FL"""
def __init__(
self,
num_byzantine: int = 0,
detection_method: str = "krum"
):
self.num_byzantine = num_byzantine
self.detection_method = detection_method
def krum(
self,
updates: List[Dict[str, np.ndarray]],
num_select: int = 1
) -> List[int]:
"""Multi-Krum Byzantine-robust selection"""
n = len(updates)
f = self.num_byzantine
# Flatten updates for distance computation
flattened = []
for update in updates:
flat = np.concatenate([v.flatten() for v in update.values()])
flattened.append(flat)
flattened = np.array(flattened)
# Compute pairwise distances
distances = np.zeros((n, n))
for i in range(n):
for j in range(i + 1, n):
dist = np.linalg.norm(flattened[i] - flattened[j])
distances[i, j] = dist
distances[j, i] = dist
# Compute Krum scores
scores = []
for i in range(n):
# Sum of distances to n - f - 2 closest neighbors
sorted_dists = np.sort(distances[i])
score = np.sum(sorted_dists[1:n - f - 1]) # Exclude self
scores.append(score)
# Select indices with lowest scores
selected = np.argsort(scores)[:num_select]
return selected.tolist()
def trimmed_mean(
self,
updates: List[Dict[str, np.ndarray]],
trim_ratio: float = 0.1
) -> Dict[str, np.ndarray]:
"""Coordinate-wise trimmed mean"""
n = len(updates)
trim_count = int(n * trim_ratio)
aggregated = {}
for layer in updates[0].keys():
# Stack updates
stacked = np.stack([u[layer] for u in updates])
# Sort along client axis
sorted_vals = np.sort(stacked, axis=0)
# Trim and average
trimmed = sorted_vals[trim_count:n - trim_count]
aggregated[layer] = np.mean(trimmed, axis=0)
return aggregated
def median(
self,
updates: List[Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
"""Coordinate-wise median (50% Byzantine tolerance)"""
aggregated = {}
for layer in updates[0].keys():
stacked = np.stack([u[layer] for u in updates])
aggregated[layer] = np.median(stacked, axis=0)
return aggregated
def bulyan(
self,
updates: List[Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
"""Bulyan aggregation (Krum + trimmed mean)"""
n = len(updates)
f = self.num_byzantine
# Step 1: Select 2f + 3 updates using multi-Krum
num_select = min(n - 2 * f, 2 * f + 3)
selected_indices = self.krum(updates, num_select)
selected_updates = [updates[i] for i in selected_indices]
# Step 2: Apply trimmed mean to selected updates
return self.trimmed_mean(selected_updates, trim_ratio=0.25)
def fltrust(
self,
updates: List[Dict[str, np.ndarray]],
server_update: Dict[str, np.ndarray]
) -> Dict[str, np.ndarray]:
"""FLTrust - use server's clean update as trust anchor"""
# Flatten for cosine similarity
server_flat = np.concatenate(
[v.flatten() for v in server_update.values()]
)
server_norm = np.linalg.norm(server_flat)
# Compute trust scores
trust_scores = []
for update in updates:
client_flat = np.concatenate(
[v.flatten() for v in update.values()]
)
# Cosine similarity as trust score
similarity = 1 - cosine(server_flat, client_flat)
# ReLU to handle negative similarities
trust_scores.append(max(0, similarity))
# Normalize trust scores
total_trust = sum(trust_scores)
if total_trust == 0:
# Fall back to uniform weights
weights = [1 / len(updates)] * len(updates)
else:
weights = [t / total_trust for t in trust_scores]
# Weighted aggregation with normalized updates
aggregated = {}
for layer in server_update.keys():
weighted_sum = np.zeros_like(server_update[layer])
for i, update in enumerate(updates):
# Normalize update to server update magnitude
client_layer = update[layer]
client_norm = np.linalg.norm(client_layer)
if client_norm > 0:
normalized = client_layer * (server_norm / client_norm)
else:
normalized = client_layer
weighted_sum += weights[i] * normalized
aggregated[layer] = weighted_sum
return aggregated
def detect_anomalies(
self,
updates: List[Dict[str, np.ndarray]]
) -> List[int]:
"""Detect anomalous updates using clustering"""
# Flatten updates
flattened = []
for update in updates:
flat = np.concatenate([v.flatten() for v in update.values()])
flattened.append(flat)
flattened = np.array(flattened)
# DBSCAN clustering
clustering = DBSCAN(eps=0.5, min_samples=2).fit(flattened)
# Identify outliers (label -1)
anomalies = np.where(clustering.labels_ == -1)[0]
        return anomalies.tolist()

Client-Side Security
Secure Client Implementation
# secure_client.py
import numpy as np
from typing import Dict, Optional
import hashlib
import hmac
class SecureFederatedClient:
"""Secure federated learning client"""
def __init__(
self,
client_id: str,
model: 'NeuralNetwork',
privacy_budget: float = 1.0
):
self.client_id = client_id
self.model = model
self.dp = DifferentialPrivacy(epsilon=privacy_budget)
self.secure_agg = None
self.server_public_key = None
def initialize_security(
self,
server_public_key: bytes,
other_clients: Dict[str, bytes]
):
"""Initialize security primitives"""
self.server_public_key = server_public_key
# Setup secure aggregation
self.secure_agg = SecureAggregation(len(other_clients) + 1)
_, public_key = self.secure_agg.generate_keypair(self.client_id)
# Compute shared secrets
self.secure_agg.compute_shared_secrets(
self.client_id, other_clients
)
return public_key
def train_local(
self,
data: np.ndarray,
labels: np.ndarray,
epochs: int = 1,
batch_size: int = 32
) -> Dict[str, np.ndarray]:
"""Train model on local data"""
initial_weights = self.model.get_weights()
# Local training
for epoch in range(epochs):
indices = np.random.permutation(len(data))
for i in range(0, len(data), batch_size):
batch_idx = indices[i:i + batch_size]
batch_x = data[batch_idx]
batch_y = labels[batch_idx]
self.model.train_step(batch_x, batch_y)
# Compute update (new weights - old weights)
new_weights = self.model.get_weights()
update = {}
for layer in initial_weights:
update[layer] = new_weights[layer] - initial_weights[layer]
return update
def prepare_update(
self,
update: Dict[str, np.ndarray],
num_samples: int,
round_number: int,
use_dp: bool = True,
use_secure_agg: bool = True
) -> ClientUpdate:
"""Prepare update for submission"""
processed = update
# Apply differential privacy
if use_dp:
processed = self.dp.privatize_update(processed, num_samples)
# Apply secure aggregation mask
if use_secure_agg and self.secure_agg:
processed = self.secure_agg.mask_update(
self.client_id, processed
)
# Sign update
signature = self._sign_update(processed, round_number)
return ClientUpdate(
client_id=self.client_id,
model_weights=processed,
num_samples=num_samples,
round_number=round_number,
signature=signature,
metadata={
"dp_applied": use_dp,
"secure_agg_applied": use_secure_agg
}
)
def _sign_update(
self,
update: Dict[str, np.ndarray],
round_number: int
) -> bytes:
"""Sign update for authentication"""
# Serialize update
serialized = b""
for layer in sorted(update.keys()):
serialized += update[layer].tobytes()
# Add round number
serialized += round_number.to_bytes(8, 'big')
# HMAC signature (simplified - use proper PKI in production)
return hmac.new(
self.client_id.encode(),
serialized,
hashlib.sha256
).digest()
def apply_global_update(
self,
global_weights: Dict[str, np.ndarray]
):
"""Apply received global weights"""
        self.model.set_weights(global_weights)

Summary
Secure federated learning requires multiple layers of protection:
- Differential privacy: Bound information leakage per participant
- Secure aggregation: Prevent server from seeing individual updates
- Byzantine defenses: Robust aggregation against poisoning
- Authentication: Verify client identities and update integrity
Implement these defenses based on your threat model - not all deployments need every protection, but understanding the tradeoffs is essential for building privacy-preserving ML systems.