Federated learning enables collaborative model training while keeping data decentralized. This guide covers security and privacy implementations for building robust federated learning systems.
Federated Learning Framework
Build a secure federated learning system:
from dataclasses import dataclass
from typing import List, Dict, Optional, Callable
from enum import Enum
import numpy as np
import hashlib
from abc import ABC, abstractmethod
class AggregationStrategy(Enum):
FEDAVG = "federated_averaging"
FEDPROX = "federated_proximal"
SCAFFOLD = "scaffold"
@dataclass
class ModelUpdate:
client_id: str
round_number: int
weights: Dict[str, np.ndarray]
num_samples: int
metrics: Dict[str, float]
signature: str
@dataclass
class GlobalModel:
version: int
weights: Dict[str, np.ndarray]
round_number: int
participating_clients: List[str]
class FederatedLearningServer:
def __init__(
self,
initial_model: Dict[str, np.ndarray],
aggregation_strategy: AggregationStrategy = AggregationStrategy.FEDAVG,
min_clients: int = 3,
client_fraction: float = 0.5
):
self.global_model = GlobalModel(
version=0,
weights=initial_model,
round_number=0,
participating_clients=[]
)
self.aggregation_strategy = aggregation_strategy
self.min_clients = min_clients
self.client_fraction = client_fraction
self.client_updates: List[ModelUpdate] = []
self.registered_clients: Dict[str, Dict] = {}
def register_client(self, client_id: str, public_key: str, metadata: Dict) -> bool:
"""Register a new client with verification."""
if client_id in self.registered_clients:
return False
self.registered_clients[client_id] = {
'public_key': public_key,
'metadata': metadata,
'rounds_participated': 0,
'last_update': None,
'reputation_score': 1.0
}
return True
def select_clients(self, round_number: int) -> List[str]:
"""Select clients for the current round."""
eligible_clients = [
cid for cid, info in self.registered_clients.items()
if info['reputation_score'] > 0.5
]
num_to_select = max(
self.min_clients,
int(len(eligible_clients) * self.client_fraction)
)
# Random selection with reputation weighting
weights = [
self.registered_clients[c]['reputation_score']
for c in eligible_clients
]
weights = np.array(weights) / sum(weights)
selected = np.random.choice(
eligible_clients,
size=min(num_to_select, len(eligible_clients)),
replace=False,
p=weights
)
return list(selected)
def receive_update(self, update: ModelUpdate) -> bool:
"""Receive and validate client update."""
if update.client_id not in self.registered_clients:
return False
if update.round_number != self.global_model.round_number:
return False
# Verify signature
if not self._verify_signature(update):
return False
# Validate update (check for anomalies)
if not self._validate_update(update):
self._penalize_client(update.client_id, 0.1)
return False
self.client_updates.append(update)
return True
def _verify_signature(self, update: ModelUpdate) -> bool:
"""Verify update signature."""
client_info = self.registered_clients.get(update.client_id)
if not client_info:
return False
        # Simplified integrity check: recompute a hash over public fields.
        # A real deployment should instead verify a cryptographic signature
        # (e.g., Ed25519) made with the client's registered public key.
data = f"{update.client_id}{update.round_number}{update.num_samples}"
expected = hashlib.sha256(data.encode()).hexdigest()
return update.signature == expected
def _validate_update(self, update: ModelUpdate) -> bool:
"""Validate update for anomalies (model poisoning detection)."""
if update.num_samples < 1:
return False
# Check for extreme weight values
for layer_name, weights in update.weights.items():
if np.any(np.isnan(weights)) or np.any(np.isinf(weights)):
return False
# Check for statistical anomalies
global_weights = self.global_model.weights.get(layer_name)
if global_weights is not None:
diff = np.abs(weights - global_weights)
if np.mean(diff) > 10 * np.std(global_weights):
return False
return True
def _penalize_client(self, client_id: str, penalty: float):
"""Reduce client reputation score."""
if client_id in self.registered_clients:
current = self.registered_clients[client_id]['reputation_score']
self.registered_clients[client_id]['reputation_score'] = max(0, current - penalty)
def aggregate(self) -> GlobalModel:
"""Aggregate client updates into new global model."""
if len(self.client_updates) < self.min_clients:
raise ValueError(f"Not enough updates: {len(self.client_updates)} < {self.min_clients}")
if self.aggregation_strategy == AggregationStrategy.FEDAVG:
new_weights = self._federated_averaging()
elif self.aggregation_strategy == AggregationStrategy.FEDPROX:
new_weights = self._federated_proximal()
else:
new_weights = self._federated_averaging()
participating = [u.client_id for u in self.client_updates]
self.global_model = GlobalModel(
version=self.global_model.version + 1,
weights=new_weights,
round_number=self.global_model.round_number + 1,
participating_clients=participating
)
# Update client stats
for client_id in participating:
self.registered_clients[client_id]['rounds_participated'] += 1
# Clear updates for next round
self.client_updates = []
return self.global_model
def _federated_averaging(self) -> Dict[str, np.ndarray]:
"""FedAvg aggregation."""
total_samples = sum(u.num_samples for u in self.client_updates)
new_weights = {}
for layer_name in self.global_model.weights.keys():
weighted_sum = np.zeros_like(self.global_model.weights[layer_name])
for update in self.client_updates:
weight = update.num_samples / total_samples
weighted_sum += weight * update.weights[layer_name]
new_weights[layer_name] = weighted_sum
return new_weights
def _federated_proximal(self, mu: float = 0.01) -> Dict[str, np.ndarray]:
"""FedProx aggregation with proximal term."""
# Similar to FedAvg but clients use proximal term during training
        return self._federated_averaging()
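Before moving on, here is a minimal single-round driving loop for the server above. It is a sketch rather than part of the framework: local training is mocked by perturbing the global weights, the layer shapes and client names are made up, and the signature is built the same way _verify_signature expects:
import hashlib
import numpy as np

# Hypothetical layer shapes; any dict of arrays works.
initial = {"w1": np.random.randn(8, 4), "b1": np.random.randn(4)}
server = FederatedLearningServer(initial_model=initial, min_clients=2, client_fraction=1.0)

for cid in ["client_a", "client_b", "client_c"]:
    server.register_client(cid, public_key=f"pk_{cid}", metadata={})

for cid in server.select_clients(round_number=0):
    # Stand-in for local training: a small perturbation of the global weights.
    local = {name: w + 0.01 * np.random.randn(*w.shape)
             for name, w in server.global_model.weights.items()}
    num_samples = 100
    sig = hashlib.sha256(
        f"{cid}{server.global_model.round_number}{num_samples}".encode()
    ).hexdigest()
    server.receive_update(ModelUpdate(
        client_id=cid,
        round_number=server.global_model.round_number,
        weights=local,
        num_samples=num_samples,
        metrics={"loss": 0.5},
        signature=sig,
    ))

new_global = server.aggregate()
print(new_global.version, new_global.participating_clients)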
Differential Privacy Implementation
Add differential privacy guarantees:
from dataclasses import dataclass
from typing import Dict, Tuple
import numpy as np
@dataclass
class PrivacyBudget:
epsilon: float
delta: float
consumed_epsilon: float = 0.0
consumed_delta: float = 0.0
class DifferentialPrivacy:
def __init__(
self,
epsilon: float,
delta: float,
max_grad_norm: float = 1.0,
noise_multiplier: float = 1.0
):
self.budget = PrivacyBudget(epsilon=epsilon, delta=delta)
self.max_grad_norm = max_grad_norm
self.noise_multiplier = noise_multiplier
def clip_gradients(self, gradients: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""Clip gradients to bounded sensitivity."""
total_norm = 0.0
for grad in gradients.values():
total_norm += np.sum(grad ** 2)
total_norm = np.sqrt(total_norm)
clip_factor = min(1.0, self.max_grad_norm / (total_norm + 1e-6))
clipped = {}
for name, grad in gradients.items():
clipped[name] = grad * clip_factor
return clipped
def add_noise(
self,
gradients: Dict[str, np.ndarray],
num_samples: int
) -> Dict[str, np.ndarray]:
"""Add calibrated Gaussian noise for DP."""
sigma = self.noise_multiplier * self.max_grad_norm / num_samples
noisy_gradients = {}
for name, grad in gradients.items():
noise = np.random.normal(0, sigma, grad.shape)
noisy_gradients[name] = grad + noise
return noisy_gradients
def apply_dp(
self,
gradients: Dict[str, np.ndarray],
num_samples: int
) -> Tuple[Dict[str, np.ndarray], float]:
"""Apply differential privacy to gradients."""
# Clip gradients
clipped = self.clip_gradients(gradients)
# Add noise
noisy = self.add_noise(clipped, num_samples)
# Compute privacy cost for this step
epsilon_step = self._compute_privacy_cost(num_samples)
self.budget.consumed_epsilon += epsilon_step
return noisy, epsilon_step
def _compute_privacy_cost(self, num_samples: int) -> float:
"""Compute epsilon cost using moments accountant."""
# Simplified privacy accounting
q = 1.0 / num_samples # Sampling probability
sigma = self.noise_multiplier
# RDP to (epsilon, delta)-DP conversion (simplified)
alpha = 2
rdp = alpha * q ** 2 / (2 * sigma ** 2)
epsilon = rdp + np.log(1 / self.budget.delta) / (alpha - 1)
return epsilon
def check_budget(self) -> bool:
"""Check if privacy budget is exhausted."""
return self.budget.consumed_epsilon < self.budget.epsilon
def get_privacy_spent(self) -> Dict:
"""Get current privacy expenditure."""
return {
'epsilon_budget': self.budget.epsilon,
'epsilon_spent': self.budget.consumed_epsilon,
'epsilon_remaining': self.budget.epsilon - self.budget.consumed_epsilon,
'delta': self.budget.delta
}
class LocalDifferentialPrivacy:
"""Client-side local differential privacy."""
def __init__(self, epsilon: float):
self.epsilon = epsilon
def randomized_response(self, bit: bool) -> bool:
"""Randomized response mechanism for single bit."""
p = np.exp(self.epsilon) / (1 + np.exp(self.epsilon))
if np.random.random() < p:
return bit
else:
return not bit
def laplace_mechanism(self, value: float, sensitivity: float) -> float:
"""Laplace mechanism for numeric values."""
scale = sensitivity / self.epsilon
noise = np.random.laplace(0, scale)
return value + noise
def privatize_vector(
self,
vector: np.ndarray,
sensitivity: float
) -> np.ndarray:
"""Apply LDP to a vector."""
scale = sensitivity / self.epsilon
noise = np.random.laplace(0, scale, vector.shape)
        return vector + noise
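As a sketch of how a client might use this before sending an update, the snippet below clips and noises a set of made-up gradients with apply_dp and then inspects the remaining budget. Because the simplified accountant above charges a large per-step epsilon (dominated by the log(1/delta) term), check_budget can fail quickly; a production system would use a proper moments or RDP accountant:
import numpy as np

dp = DifferentialPrivacy(epsilon=3.0, delta=1e-5, max_grad_norm=1.0, noise_multiplier=1.1)

# Illustrative gradients for a local batch of 100 samples.
gradients = {"w1": np.random.randn(8, 4), "b1": np.random.randn(4)}
noisy_grads, eps_step = dp.apply_dp(gradients, num_samples=100)

print(f"epsilon spent this step: {eps_step:.3f}")
print(dp.get_privacy_spent())
if not dp.check_budget():
    print("Privacy budget exhausted; the client should stop participating.")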
Secure Aggregation Protocol
Implement secure aggregation for privacy:
from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
import secrets
@dataclass
class SecretShare:
client_id: str
share_index: int
share_value: np.ndarray
commitment: bytes
class SecureAggregation:
def __init__(self, threshold: int, num_clients: int):
self.threshold = threshold # Minimum clients needed
self.num_clients = num_clients
self.client_keys: Dict[str, bytes] = {}
self.shares: Dict[str, List[SecretShare]] = {}
def generate_client_keys(self, client_id: str) -> Tuple[bytes, bytes]:
"""Generate key pair for a client."""
private_key = secrets.token_bytes(32)
public_key = self._derive_public_key(private_key)
self.client_keys[client_id] = public_key
return private_key, public_key
def _derive_public_key(self, private_key: bytes) -> bytes:
"""Derive public key from private key."""
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b'public_key'
)
return hkdf.derive(private_key)
def create_pairwise_masks(
self,
client_id: str,
private_key: bytes,
other_clients: List[str],
vector_shape: Tuple
) -> Dict[str, np.ndarray]:
"""Create pairwise masks for secure aggregation."""
masks = {}
for other_id in other_clients:
if other_id == client_id:
continue
other_public_key = self.client_keys.get(other_id)
if not other_public_key:
continue
# Derive shared secret
shared_secret = self._derive_shared_secret(private_key, other_public_key)
# Generate deterministic mask from shared secret
mask = self._generate_mask(shared_secret, vector_shape, client_id, other_id)
masks[other_id] = mask
return masks
def _derive_shared_secret(self, private_key: bytes, public_key: bytes) -> bytes:
"""Derive shared secret using ECDH-like mechanism."""
combined = private_key + public_key
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=None,
info=b'shared_secret'
)
return hkdf.derive(combined)
def _generate_mask(
self,
seed: bytes,
shape: Tuple,
client_a: str,
client_b: str
) -> np.ndarray:
"""Generate PRG-based mask."""
# Deterministic ordering
if client_a < client_b:
sign = 1
else:
sign = -1
# Use seed to initialize RNG
seed_int = int.from_bytes(seed[:4], 'big')
rng = np.random.RandomState(seed_int)
mask = rng.standard_normal(shape) * sign
return mask
def mask_update(
self,
update: np.ndarray,
masks: Dict[str, np.ndarray]
) -> np.ndarray:
"""Mask update with pairwise masks."""
masked = update.copy()
for mask in masks.values():
masked += mask
return masked
def aggregate_masked_updates(
self,
masked_updates: Dict[str, np.ndarray],
dropout_clients: List[str] = None
) -> np.ndarray:
"""Aggregate masked updates - masks cancel out."""
if dropout_clients is None:
dropout_clients = []
# If no dropouts, masks cancel perfectly
aggregated = np.zeros_like(list(masked_updates.values())[0])
for client_id, update in masked_updates.items():
if client_id not in dropout_clients:
aggregated += update
return aggregated / len(masked_updates)
def shamir_secret_share(
self,
secret: np.ndarray,
num_shares: int,
threshold: int
) -> List[Tuple[int, np.ndarray]]:
"""Create Shamir secret shares."""
# Generate random coefficients
coefficients = [secret]
for _ in range(threshold - 1):
coefficients.append(np.random.randn(*secret.shape))
shares = []
for i in range(1, num_shares + 1):
share = np.zeros_like(secret)
for j, coef in enumerate(coefficients):
share += coef * (i ** j)
shares.append((i, share))
return shares
def reconstruct_secret(
self,
shares: List[Tuple[int, np.ndarray]],
threshold: int
) -> np.ndarray:
"""Reconstruct secret from Shamir shares."""
if len(shares) < threshold:
raise ValueError(f"Need at least {threshold} shares")
shares = shares[:threshold]
secret = np.zeros_like(shares[0][1])
for i, (xi, yi) in enumerate(shares):
# Lagrange basis polynomial
numerator = 1.0
denominator = 1.0
for j, (xj, _) in enumerate(shares):
if i != j:
numerator *= (0 - xj)
denominator *= (xi - xj)
secret += yi * (numerator / denominator)
        return secret
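The pairwise-mask path depends on a real key agreement (see the note in _derive_shared_secret), so the sketch below only exercises the Shamir-share helpers, which are self-contained. It splits a made-up vector into five shares and reconstructs it from any three; reconstruction is exact only up to floating-point error because the shares live over the reals:
import numpy as np

sa = SecureAggregation(threshold=3, num_clients=5)

secret = np.random.randn(10)  # e.g., a client's mask seed expanded to a vector
shares = sa.shamir_secret_share(secret, num_shares=5, threshold=3)

subset = [shares[0], shares[2], shares[4]]  # any 3 of the 5 shares suffice
recovered = sa.reconstruct_secret(subset, threshold=3)
print(np.allclose(recovered, secret))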
Model Poisoning Defense
Implement defenses against model poisoning attacks:
from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np
from scipy import stats
@dataclass
class ClientAnalysis:
client_id: str
is_malicious: bool
anomaly_score: float
detection_method: str
class ModelPoisoningDefense:
def __init__(self, detection_threshold: float = 2.0):
self.detection_threshold = detection_threshold
self.client_history: Dict[str, List[np.ndarray]] = {}
def analyze_updates(
self,
updates: List[ModelUpdate],
global_weights: Dict[str, np.ndarray]
) -> List[ClientAnalysis]:
"""Analyze updates for potential poisoning."""
analyses = []
for update in updates:
# Multiple detection methods
krum_score = self._krum_score(update, updates)
cosine_score = self._cosine_similarity_score(update, global_weights)
magnitude_score = self._magnitude_score(update, global_weights)
history_score = self._history_consistency_score(update)
# Combined anomaly score
anomaly_score = (
0.3 * krum_score +
0.3 * cosine_score +
0.2 * magnitude_score +
0.2 * history_score
)
is_malicious = anomaly_score > self.detection_threshold
detection_method = self._get_primary_detector(
krum_score, cosine_score, magnitude_score, history_score
)
analyses.append(ClientAnalysis(
client_id=update.client_id,
is_malicious=is_malicious,
anomaly_score=anomaly_score,
detection_method=detection_method
))
# Update history
update_vector = self._flatten_weights(update.weights)
if update.client_id not in self.client_history:
self.client_history[update.client_id] = []
self.client_history[update.client_id].append(update_vector)
return analyses
def _krum_score(self, target: ModelUpdate, all_updates: List[ModelUpdate]) -> float:
"""Multi-Krum anomaly score."""
target_vec = self._flatten_weights(target.weights)
distances = []
for update in all_updates:
if update.client_id == target.client_id:
continue
vec = self._flatten_weights(update.weights)
dist = np.linalg.norm(target_vec - vec)
distances.append(dist)
if not distances:
return 0.0
# Score based on sum of nearest neighbors
distances.sort()
n_neighbors = min(len(distances), len(all_updates) - 2)
krum_score = sum(distances[:n_neighbors])
# Normalize
median_score = np.median([
sum(sorted([
np.linalg.norm(self._flatten_weights(u1.weights) - self._flatten_weights(u2.weights))
for u2 in all_updates if u1.client_id != u2.client_id
])[:n_neighbors])
for u1 in all_updates
])
return krum_score / (median_score + 1e-6)
def _cosine_similarity_score(
self,
update: ModelUpdate,
global_weights: Dict[str, np.ndarray]
) -> float:
"""Score based on cosine similarity with global model direction."""
update_vec = self._flatten_weights(update.weights)
global_vec = self._flatten_weights(global_weights)
# Direction from global to update
direction = update_vec - global_vec
# Expected direction (simplified - use historical average)
norm = np.linalg.norm(direction)
if norm < 1e-6:
return 0.0
# Low similarity with expected direction = anomaly
# For simplicity, check if update moves model significantly
cos_sim = np.dot(direction, global_vec) / (norm * np.linalg.norm(global_vec) + 1e-6)
# Very negative similarity is suspicious
return max(0, -cos_sim)
def _magnitude_score(
self,
update: ModelUpdate,
global_weights: Dict[str, np.ndarray]
) -> float:
"""Score based on update magnitude."""
update_vec = self._flatten_weights(update.weights)
global_vec = self._flatten_weights(global_weights)
diff = update_vec - global_vec
magnitude = np.linalg.norm(diff)
# Z-score based on expected magnitude
expected_magnitude = np.linalg.norm(global_vec) * 0.1 # Heuristic
std_magnitude = expected_magnitude * 0.5
z_score = abs(magnitude - expected_magnitude) / (std_magnitude + 1e-6)
return z_score
def _history_consistency_score(self, update: ModelUpdate) -> float:
"""Score based on consistency with client's history."""
history = self.client_history.get(update.client_id, [])
if len(history) < 2:
return 0.0
current = self._flatten_weights(update.weights)
# Check if current update is consistent with history
historical_changes = []
for i in range(1, len(history)):
change = np.linalg.norm(history[i] - history[i-1])
historical_changes.append(change)
if not historical_changes:
return 0.0
current_change = np.linalg.norm(current - history[-1])
mean_change = np.mean(historical_changes)
std_change = np.std(historical_changes) + 1e-6
z_score = abs(current_change - mean_change) / std_change
return z_score
def _flatten_weights(self, weights: Dict[str, np.ndarray]) -> np.ndarray:
"""Flatten weight dict to single vector."""
return np.concatenate([w.flatten() for w in weights.values()])
def _get_primary_detector(self, *scores) -> str:
"""Get which detector contributed most to anomaly."""
methods = ['krum', 'cosine', 'magnitude', 'history']
max_idx = np.argmax(scores)
return methods[max_idx]
def robust_aggregation(
self,
updates: List[ModelUpdate],
global_weights: Dict[str, np.ndarray],
method: str = 'trimmed_mean'
) -> Dict[str, np.ndarray]:
"""Perform robust aggregation resistant to poisoning."""
if method == 'trimmed_mean':
return self._trimmed_mean_aggregation(updates, trim_ratio=0.1)
elif method == 'median':
return self._median_aggregation(updates)
elif method == 'krum':
return self._krum_aggregation(updates)
else:
raise ValueError(f"Unknown method: {method}")
def _trimmed_mean_aggregation(
self,
updates: List[ModelUpdate],
trim_ratio: float
) -> Dict[str, np.ndarray]:
"""Trimmed mean aggregation."""
n_trim = int(len(updates) * trim_ratio)
aggregated = {}
for layer_name in updates[0].weights.keys():
layer_updates = np.array([u.weights[layer_name] for u in updates])
# Sort and trim along client axis
sorted_updates = np.sort(layer_updates, axis=0)
if n_trim > 0:
trimmed = sorted_updates[n_trim:-n_trim]
else:
trimmed = sorted_updates
aggregated[layer_name] = np.mean(trimmed, axis=0)
return aggregated
def _median_aggregation(self, updates: List[ModelUpdate]) -> Dict[str, np.ndarray]:
"""Coordinate-wise median aggregation."""
aggregated = {}
for layer_name in updates[0].weights.keys():
layer_updates = np.array([u.weights[layer_name] for u in updates])
aggregated[layer_name] = np.median(layer_updates, axis=0)
return aggregated
def _krum_aggregation(self, updates: List[ModelUpdate]) -> Dict[str, np.ndarray]:
"""Select update with minimum Krum score."""
min_score = float('inf')
best_update = None
for update in updates:
score = self._krum_score(update, updates)
if score < min_score:
min_score = score
best_update = update
        return best_update.weights if best_update else updates[0].weights
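To see the defense in action, the sketch below builds a few synthetic ModelUpdate objects (reusing the dataclass from the framework section), makes one of them a crude large-magnitude poisoning attempt, scores them, and aggregates with the coordinate-wise median. Client names, shapes, and scales are purely illustrative:
import numpy as np

defense = ModelPoisoningDefense(detection_threshold=2.0)
global_weights = {"w1": np.random.randn(8, 4)}

def make_update(cid: str, scale: float) -> ModelUpdate:
    # Honest clients drift a little; the attacker pushes the weights hard.
    return ModelUpdate(
        client_id=cid,
        round_number=1,
        weights={"w1": global_weights["w1"] + scale * np.random.randn(8, 4)},
        num_samples=100,
        metrics={},
        signature="",
    )

updates = [make_update(f"client_{i}", 0.05) for i in range(4)]
updates.append(make_update("attacker", 5.0))

for analysis in defense.analyze_updates(updates, global_weights):
    print(analysis.client_id, analysis.is_malicious, round(analysis.anomaly_score, 2))

robust_weights = defense.robust_aggregation(updates, global_weights, method="median")
print(robust_weights["w1"].shape)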
Conclusion
Secure federated learning requires multiple layers of protection, including differential privacy, secure aggregation, and model poisoning defenses. Implement client verification and reputation systems to track trustworthiness. Use differential privacy to provide formal privacy guarantees. Deploy secure aggregation so the server never sees individual updates. Add robust aggregation methods to defend against poisoning attacks. Remember that privacy and security in federated learning trade off against model utility, so calibrate parameters carefully for your specific requirements.